diff --git a/.buildkite/ci_config.yaml b/.buildkite/ci_config.yaml new file mode 100644 index 000000000000..199c33159fde --- /dev/null +++ b/.buildkite/ci_config.yaml @@ -0,0 +1,24 @@ +name: vllm_ci +job_dirs: + - ".buildkite/test_areas" + - ".buildkite/image_build" +run_all_patterns: + - "docker/Dockerfile" + - "CMakeLists.txt" + - "requirements/common.txt" + - "requirements/cuda.txt" + - "requirements/build.txt" + - "requirements/test.txt" + - "setup.py" + - "csrc/" + - "cmake/" +run_all_exclude_patterns: + - "docker/Dockerfile." + - "csrc/cpu/" + - "csrc/rocm/" + - "cmake/hipify.py" + - "cmake/cpu_extension.cmake" +registries: public.ecr.aws/q9t5s3a7 +repositories: + main: "vllm-ci-postmerge-repo" + premerge: "vllm-ci-test-repo" diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py deleted file mode 100644 index bbed80ebe847..000000000000 --- a/.buildkite/generate_index.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import os - -template = """ - - -

-        <h1>Links for vLLM</h1>
-        <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
-        <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
- - -""" - -parser = argparse.ArgumentParser() -parser.add_argument("--wheel", help="The wheel path.", required=True) -args = parser.parse_args() - -filename = os.path.basename(args.wheel) - -with open("index.html", "w") as f: - print(f"Generated index.html for {args.wheel}") - # sync the abi tag with .buildkite/scripts/upload-wheels.sh - if "x86_64" in filename: - x86_wheel = filename - arm_wheel = filename.replace("x86_64", "aarch64").replace( - "manylinux1", "manylinux2014" - ) - elif "aarch64" in filename: - x86_wheel = filename.replace("aarch64", "x86_64").replace( - "manylinux2014", "manylinux1" - ) - arm_wheel = filename - else: - raise ValueError(f"Unsupported wheel: {filename}") - # cloudfront requires escaping the '+' character - f.write( - template.format( - x86_wheel=x86_wheel, - x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), - arm_wheel=arm_wheel, - arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), - ) - ) diff --git a/.buildkite/image_build/image_build.sh b/.buildkite/image_build/image_build.sh new file mode 100755 index 000000000000..9a2384e524b6 --- /dev/null +++ b/.buildkite/image_build/image_build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +set -e + +if [[ $# -lt 8 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 +BRANCH=$4 +VLLM_USE_PRECOMPILED=$5 +VLLM_MERGE_BASE_COMMIT=$6 +CACHE_FROM=$7 +CACHE_TO=$8 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY +aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com + +# docker buildx +docker buildx create --name vllm-builder --driver docker-container --use +docker buildx inspect --bootstrap +docker buildx ls + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +if [[ "${VLLM_USE_PRECOMPILED:-0}" == "1" ]]; then + merge_base_commit_build_args="--build-arg VLLM_MERGE_BASE_COMMIT=${VLLM_MERGE_BASE_COMMIT}" +else + merge_base_commit_build_args="" +fi + +# build +docker buildx build --file docker/Dockerfile \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --build-arg USE_SCCACHE=1 \ + --build-arg TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0 10.0" \ + --build-arg FI_TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0a 10.0a" \ + --build-arg VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED:-0}" \ + ${merge_base_commit_build_args} \ + --cache-from type=registry,ref=${CACHE_FROM},mode=max \ + --cache-to type=registry,ref=${CACHE_TO},mode=max \ + --tag ${REGISTRY}/${REPO}:${BUILDKITE_COMMIT} \ + $( [[ "${BRANCH}" == "main" ]] && echo "--tag ${REGISTRY}/${REPO}:latest" ) \ + --push \ + --target test \ + --progress plain . 
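For reference, a premerge invocation of this script might look like the sketch below. The registry and repository values are taken from `ci_config.yaml` above; `$BUILDKITE_BRANCH` and the `build-cache` registry refs are illustrative placeholders rather than values defined in this change, and `VLLM_USE_PRECOMPILED` is left at `0` so the merge-base commit argument can stay empty.

```bash
# Argument order expected by image_build.sh:
#   REGISTRY REPO BUILDKITE_COMMIT BRANCH VLLM_USE_PRECOMPILED VLLM_MERGE_BASE_COMMIT CACHE_FROM CACHE_TO
# Passing BRANCH="main" additionally tags the image as :latest.
.buildkite/image_build/image_build.sh \
  public.ecr.aws/q9t5s3a7 \
  vllm-ci-test-repo \
  "$BUILDKITE_COMMIT" \
  "$BUILDKITE_BRANCH" \
  0 \
  "" \
  public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:build-cache \
  public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:build-cache
```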
diff --git a/.buildkite/image_build/image_build.yaml b/.buildkite/image_build/image_build.yaml new file mode 100644 index 000000000000..d01c71dd9bec --- /dev/null +++ b/.buildkite/image_build/image_build.yaml @@ -0,0 +1,57 @@ +group: Abuild +steps: + - label: ":docker: Build image" + key: image-build + depends_on: [] + commands: + - .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $CACHE_FROM $CACHE_TO + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build CPU image" + key: image-build-cpu + depends_on: [] + commands: + - .buildkite/image_build/image_build_cpu.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build HPU image" + soft_fail: true + depends_on: [] + key: image-build-hpu + commands: + - .buildkite/image_build/image_build_hpu.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build CPU arm64 image" + key: cpu-arm64-image-build + depends_on: [] + optional: true + commands: + - .buildkite/image_build/image_build_cpu_arm64.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 diff --git a/.buildkite/image_build/image_build_cpu.sh b/.buildkite/image_build/image_build_cpu.sh new file mode 100755 index 000000000000..a69732f43098 --- /dev/null +++ b/.buildkite/image_build/image_build_cpu.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +# build +docker build --file docker/Dockerfile.cpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --build-arg VLLM_CPU_AVX512BF16=true \ + --build-arg VLLM_CPU_AVX512VNNI=true \ + --build-arg VLLM_CPU_AMXBF16=true \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu \ + --target vllm-test \ + --progress plain . + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu diff --git a/.buildkite/image_build/image_build_cpu_arm64.sh b/.buildkite/image_build/image_build_cpu_arm64.sh new file mode 100755 index 000000000000..615298b6555b --- /dev/null +++ b/.buildkite/image_build/image_build_cpu_arm64.sh @@ -0,0 +1,33 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu) ]]; then + echo "Image not found, proceeding with build..." 
+else + echo "Image found" + exit 0 +fi + +# build +docker build --file docker/Dockerfile.cpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu \ + --target vllm-test \ + --progress plain . + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu diff --git a/.buildkite/image_build/image_build_hpu.sh b/.buildkite/image_build/image_build_hpu.sh new file mode 100755 index 000000000000..192447ef4577 --- /dev/null +++ b/.buildkite/image_build/image_build_hpu.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +# build +docker build \ + --file tests/pytorch_ci_hud_benchmark/Dockerfile.hpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu \ + --progress plain \ + https://github.com/vllm-project/vllm-gaudi.git + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml index 46f1a9fbf6ff..6c0b5540cbb6 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml @@ -8,3 +8,4 @@ tasks: value: 0.80 limit: 250 # will run on 250 * 14 subjects = 3500 samples num_fewshot: 5 +rtol: 0.05 diff --git a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt new file mode 100644 index 000000000000..4fb0b84bc4d8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt @@ -0,0 +1 @@ +Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 3627b760eddc..f94d681197d2 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -9,11 +9,40 @@ --tp-size=1 """ +import os +from contextlib import contextmanager + import lm_eval import numpy as np import yaml -RTOL = 0.08 +DEFAULT_RTOL = 0.08 + + +@contextmanager +def scoped_env_vars(new_env: dict[str, str]): + if not new_env: + # Fast path: nothing to do + yield + return + + old_values = {} + new_keys = [] + + try: + for key, value in new_env.items(): + if key in os.environ: + old_values[key] = os.environ[key] + else: + new_keys.append(key) + os.environ[key] = str(value) + yield + finally: + # Restore / clean up + for key, value in old_values.items(): + os.environ[key] = value + for key in new_keys: + os.environ.pop(key, None) def launch_lm_eval(eval_config, tp_size): @@ -32,23 +61,26 @@ def launch_lm_eval(eval_config, tp_size): f"trust_remote_code={trust_remote_code}," f"max_model_len={max_model_len}," ) - results = lm_eval.simple_evaluate( - model=backend, - model_args=model_args, - tasks=[task["name"] for task in eval_config["tasks"]], - num_fewshot=eval_config["num_fewshot"], - 
limit=eval_config["limit"], - # TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help - # text models. however, this is regressing measured strict-match for - # existing text models in CI, so only apply it for mm, or explicitly set - apply_chat_template=eval_config.get( - "apply_chat_template", backend == "vllm-vlm" - ), - fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", False), - # Forward decoding and early-stop controls (e.g., max_gen_toks, until=...) - gen_kwargs=eval_config.get("gen_kwargs"), - batch_size=batch_size, - ) + + env_vars = eval_config.get("env_vars", None) + with scoped_env_vars(env_vars): + results = lm_eval.simple_evaluate( + model=backend, + model_args=model_args, + tasks=[task["name"] for task in eval_config["tasks"]], + num_fewshot=eval_config["num_fewshot"], + limit=eval_config["limit"], + # TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help + # text models. however, this is regressing measured strict-match for + # existing text models in CI, so only apply it for mm, or explicitly set + apply_chat_template=eval_config.get( + "apply_chat_template", backend == "vllm-vlm" + ), + fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", False), + # Forward decoding and early-stop controls (e.g., max_gen_toks, until=...) + gen_kwargs=eval_config.get("gen_kwargs"), + batch_size=batch_size, + ) return results @@ -57,6 +89,8 @@ def test_lm_eval_correctness_param(config_filename, tp_size): results = launch_lm_eval(eval_config, tp_size) + rtol = eval_config.get("rtol", DEFAULT_RTOL) + success = True for task in eval_config["tasks"]: for metric in task["metrics"]: @@ -64,8 +98,9 @@ def test_lm_eval_correctness_param(config_filename, tp_size): measured_value = results["results"][task["name"]][metric["name"]] print( f"{task['name']} | {metric['name']}: " - f"ground_truth={ground_truth} | measured={measured_value}" + f"ground_truth={ground_truth:.3f} | " + f"measured={measured_value:.3f} | rtol={rtol}" ) - success = success and np.isclose(ground_truth, measured_value, rtol=RTOL) + success = success and np.isclose(ground_truth, measured_value, rtol=rtol) assert success diff --git a/.buildkite/performance-benchmarks/README.md b/.buildkite/performance-benchmarks/README.md index 6d494f64f14f..015f48c2520d 100644 --- a/.buildkite/performance-benchmarks/README.md +++ b/.buildkite/performance-benchmarks/README.md @@ -108,6 +108,65 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. +#### Default Parameters Field + +We can specify default parameters in a JSON field with key `defaults`. Parameters defined in the field are applied globally to all serving tests, and can be overridden in test case fields. Here is an example: + +
+<details>
+<summary>An Example of default parameters field</summary>
+
+```json
+{
+    "defaults": {
+        "qps_list": [
+            "inf"
+        ],
+        "server_environment_variables": {
+            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1
+        },
+        "server_parameters": {
+            "tensor_parallel_size": 1,
+            "dtype": "bfloat16",
+            "block_size": 128,
+            "disable_log_stats": "",
+            "load_format": "dummy"
+        },
+        "client_parameters": {
+            "backend": "vllm",
+            "dataset_name": "random",
+            "random-input-len": 128,
+            "random-output-len": 128,
+            "num_prompts": 200,
+            "ignore-eos": ""
+        }
+    },
+    "tests": [
+        {
+            "test_name": "serving_llama3B_tp2_random_128_128",
+            "server_parameters": {
+                "model": "meta-llama/Llama-3.2-3B-Instruct",
+                "tensor_parallel_size": 2
+            },
+            "client_parameters": {
+                "model": "meta-llama/Llama-3.2-3B-Instruct"
+            }
+        },
+        {
+            "test_name": "serving_qwen3_tp4_random_128_128",
+            "server_parameters": {
+                "model": "Qwen/Qwen3-14B",
+                "tensor_parallel_size": 4
+            },
+            "client_parameters": {
+                "model": "Qwen/Qwen3-14B"
+            }
+        }
+    ]
+}
+```
+
+</details>
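Note that when the runner consumes this file, each test case is merged over the `defaults` block and test-level values take precedence: in the first test above, `tensor_parallel_size` resolves to 2, while `dtype`, `block_size`, `load_format`, and the client parameters are inherited from `defaults`, and `qps_list` falls back to `["inf"]` because the test does not set it.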
+ ### Visualizing the results The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results. diff --git a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh index 99a5a5e334f8..34ceefe0996f 100644 --- a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh @@ -110,7 +110,8 @@ json2envs() { wait_for_server() { # wait for vllm server to start # return 1 if vllm server crashes - timeout 1200 bash -c ' + local timeout_val="1200" + timeout "$timeout_val" bash -c ' until curl -X POST localhost:8000/v1/completions; do sleep 1 done' && return 0 || return 1 @@ -316,12 +317,44 @@ run_throughput_tests() { run_serving_tests() { # run serving tests using `vllm bench serve` command # $1: a json file specifying serving test cases + # + # Supported JSON formats: + # 1) Plain format: top-level array + # [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # + # 2) Default parameters field + plain format tests + # { + # "defaults": { ... }, + # "tests": [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # } local serving_test_file serving_test_file=$1 # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do + jq -c ' + if type == "array" then + # Plain format: test cases array + .[] + elif (type == "object" and has("tests")) then + # merge the default parameters into each test cases + . as $root + | ($root.defaults // {}) as $d + | ($root.tests // [])[] + # default qps / max_concurrency from defaults if missing + | .qps_list = (.qps_list // $d.qps_list) + | .max_concurrency_list = (.max_concurrency_list // $d.max_concurrency_list) + # merge envs / params: test overrides defaults + | .server_environment_variables = + (($d.server_environment_variables // {}) + (.server_environment_variables // {})) + | .server_parameters = + (($d.server_parameters // {}) + (.server_parameters // {})) + | .client_parameters = + (($d.client_parameters // {}) + (.client_parameters // {})) + else + error("Unsupported serving test file format: must be array or object with .tests") + end + ' "$serving_test_file" | while read -r params; do # get the test name, and append the GPU type back to it. test_name=$(echo "$params" | jq -r '.test_name') if [[ ! 
"$test_name" =~ ^serving_ ]]; then @@ -335,20 +368,25 @@ run_serving_tests() { continue fi - # get client and server arguments + # get client and server arguments (after merged the default parameters) server_params=$(echo "$params" | jq -r '.server_parameters') server_envs=$(echo "$params" | jq -r '.server_environment_variables') client_params=$(echo "$params" | jq -r '.client_parameters') + server_args=$(json2args "$server_params") server_envs=$(json2envs "$server_envs") client_args=$(json2args "$client_params") + + # qps_list qps_list=$(echo "$params" | jq -r '.qps_list') qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') echo "Running over qps list $qps_list" + + # max_concurrency_list (fallback to num_prompts if missing) max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list') if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then - num_prompts=$(echo "$client_params" | jq -r '.num_prompts') - max_concurrency_list="[$num_prompts]" + num_prompts=$(echo "$client_params" | jq -r '.num_prompts') + max_concurrency_list="[$num_prompts]" fi max_concurrency_list=$(echo "$max_concurrency_list" | jq -r '.[] | @sh') echo "Running over max concurrency list $max_concurrency_list" diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json deleted file mode 100644 index f758097e098e..000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json +++ /dev/null @@ -1,610 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - 
"VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - 
"random-input-len": 128, - "random-output-len": 128, - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", 
- "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - 
"VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 
256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json deleted file mode 100644 index 0b1a42e79025..000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json +++ /dev/null @@ -1,1023 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_sharegpt", - "qps_list": ["inf"], - 
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": 
"random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - 
"distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - 
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": 
"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - 
"VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": 
"hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json index f792956f3947..8f7200862d20 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json @@ -1,276 +1,246 @@ -[ +{ + "defaults": { + "qps_list": [ + "inf" + ], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + 
"block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "ignore-eos": "", + "num_prompts": 200 + } + }, + "tests": [ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp2_sharegpt", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp1_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, { - "test_name": "serving_llama8B_tp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - 
} + "test_name": "serving_llama8B_tp4_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp2_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - } + "test_name": "serving_llama3B_tp1_random_128_128", + "server_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp1_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_granite2B_tp1_random_128_128", + "server_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp2_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - 
"random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_qwen1.7B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-1.7B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-1.7B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp1_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_qwen4B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-4B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-4B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp2_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_qwen8B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-8B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-8B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp1_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": 
"meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_glm9B_tp1_random_128_128", + "server_parameters": { + "model": "zai-org/glm-4-9b-hf", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "zai-org/glm-4-9b-hf", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } }, { - "test_name": "serving_llama8B_tp2_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "test_name": "serving_gemma7B_tp1_random_128_128", + "server_parameters": { + "model": "google/gemma-7b", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "google/gemma-7b", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } } -] + ] +} diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 38c400ba1faf..a9d51557bd9b 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -8,41 +8,43 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - # aarch64 build - - label: "Build arm64 CPU wheel" + - label: "Build arm64 wheel - CUDA 13.0" depends_on: ~ - id: build-wheel-arm64-cpu + id: build-wheel-arm64-cuda-13-0 agents: queue: arm64_cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." 
+ # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: + # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" env: DOCKER_BUILDKIT: "1" - # x86 + CUDA builds - - label: "Build wheel - CUDA 12.8" + # aarch64 build + - label: "Build arm64 CPU wheel" depends_on: ~ - id: build-wheel-cuda-12-8 + id: build-wheel-arm64-cpu agents: - queue: cpu_queue_postmerge + queue: arm64_cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" env: DOCKER_BUILDKIT: "1" + # x86 + CUDA builds - label: "Build wheel - CUDA 12.9" depends_on: ~ id: build-wheel-cuda-12-9 @@ -52,7 +54,7 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31" env: DOCKER_BUILDKIT: "1" @@ -65,7 +67,21 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" + env: + DOCKER_BUILDKIT: "1" + + # x86 CPU wheel build + - label: "Build x86 CPU wheel" + depends_on: ~ + id: build-wheel-x86-cpu + agents: + queue: cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." 
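The build steps above now pass an explicit manylinux tag to upload-wheels.sh (manylinux_2_35 for the Ubuntu 22.04 / CUDA 13.0 based images) instead of letting the script infer one from the architecture. A hedged sketch of the mapping this encodes, following the PEP 600 glibc-based tags; the table below is illustrative and only covers the base images named in this pipeline.

# Hypothetical helper: pick the manylinux tag for a given build base image.
# The pairs mirror the arguments passed to upload-wheels.sh in these steps.
BASE_IMAGE_TO_MANYLINUX = {
    "ubuntu:20.04": "manylinux_2_31",  # default build image, glibc 2.31
    "nvidia/cuda:13.0.1-devel-ubuntu22.04": "manylinux_2_35",  # glibc 2.35
}

def manylinux_tag_for(base_image: str) -> str:
    # Fall back to the pipeline default when the image is not listed.
    return BASE_IMAGE_TO_MANYLINUX.get(base_image, "manylinux_2_31")

print(manylinux_tag_for("nvidia/cuda:13.0.1-devel-ubuntu22.04"))  # manylinux_2_35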
+ - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" env: DOCKER_BUILDKIT: "1" @@ -109,7 +125,6 @@ steps: - label: "Annotate release workflow" depends_on: - create-multi-arch-manifest - - build-wheel-cuda-12-8 id: annotate-release-workflow agents: queue: cpu_queue_postmerge diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py new file mode 100644 index 000000000000..d0965fbd5640 --- /dev/null +++ b/.buildkite/scripts/generate-nightly-index.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# do not complain about line length (for docstring) +# ruff: noqa: E501 + +import argparse +import json +import sys +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any +from urllib.parse import quote + +import regex as re + +if not sys.version_info >= (3, 12): + raise RuntimeError("This script requires Python 3.12 or higher.") + +INDEX_HTML_TEMPLATE = """ + + + + +{items} + + +""" + + +@dataclass +class WheelFileInfo: + package_name: str + version: str + build_tag: str | None + python_tag: str + abi_tag: str + platform_tag: str + variant: str | None + filename: str + + +def parse_from_filename(file: str) -> WheelFileInfo: + """ + Parse wheel file name to extract metadata. + + The format of wheel names: + {package_name}-{version}(-{build_tag})?-{python_tag}-{abi_tag}-{platform_tag}.whl + All versions could contain a variant like '+cu129' or '.cpu' or `.rocm` (or not). + Example: + vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl + vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl + """ + wheel_file_re = re.compile( + r"^(?P.+)-(?P[^-]+?)(-(?P[^-]+))?-(?P[^-]+)-(?P[^-]+)-(?P[^-]+)\.whl$" + ) + match = wheel_file_re.match(file) + if not match: + raise ValueError(f"Invalid wheel file name: {file}") + + package_name = match.group("package_name") + version = match.group("version") + build_tag = match.group("build_tag") + python_tag = match.group("python_tag") + abi_tag = match.group("abi_tag") + platform_tag = match.group("platform_tag") + + # extract variant from version + variant = None + if "dev" in version: + ver_after_dev = version.split("dev")[-1] + if "." in ver_after_dev: + variant = ver_after_dev.split(".")[-1] + version = version.removesuffix("." + variant) + else: + if "+" in version: + version, variant = version.split("+") + + return WheelFileInfo( + package_name=package_name, + version=version, + build_tag=build_tag, + python_tag=python_tag, + abi_tag=abi_tag, + platform_tag=platform_tag, + variant=variant, + filename=file, + ) + + +def generate_project_list(subdir_names: list[str], comment: str = "") -> str: + """ + Generate project list HTML content linking to each project & variant sub-directory. + """ + href_tags = [] + for name in sorted(subdir_names): + name = name.strip("/").strip(".") + href_tags.append(f' {name}/
+ + +def generate_package_index_and_metadata( + wheel_files: list[WheelFileInfo], + wheel_base_dir: Path, + index_base_dir: Path, + comment: str = "", +) -> tuple[str, str]: + """ + Generate package index HTML content for a specific package, linking to actual wheel files. + """ + href_tags = [] + metadata = [] + for file in sorted(wheel_files, key=lambda x: x.filename): + relative_path = ( + wheel_base_dir.relative_to(index_base_dir, walk_up=True) / file.filename + ) + # handle '+' in URL, and avoid double-encoding '/' and already-encoded '%2B' + # NOTE: this is AWS S3 specific behavior! + file_path_quoted = quote(relative_path.as_posix(), safe=":%/") + href_tags.append(f'<a href="{file_path_quoted}">{file.filename}</a><br/>
') + file_meta = asdict(file) + file_meta["path"] = file_path_quoted + metadata.append(file_meta) + index_str = INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags), comment=comment) + metadata_str = json.dumps(metadata, indent=2) + return index_str, metadata_str + + +def generate_index_and_metadata( + whl_files: list[str], + wheel_base_dir: Path, + index_base_dir: Path, + default_variant: str | None = None, + alias_to_default: str | None = None, + comment: str = "", +): + """ + Generate index for all wheel files. + + Args: + whl_files (list[str]): List of wheel files (must be directly under `wheel_base_dir`). + wheel_base_dir (Path): Base directory for wheel files. + index_base_dir (Path): Base directory to store index files. + default_variant (str | None): The default variant name, if any. + alias_to_default (str | None): Alias variant name for the default variant, if any. + comment (str | None): Optional comment to include in the generated HTML files. + + First, parse all wheel files to extract metadata. + We need to collect all wheel files for each variant, and generate an index for it (in a sub-directory). + The index for the default variant (if any) is generated in the root index directory. + + If `default_variant` is provided, all wheels must have variant suffixes, and the default variant index + is purely a copy of the corresponding variant index, with only the links adjusted. + Otherwise, all wheels without variant suffixes are treated as the default variant. + + If `alias_to_default` is provided, an additional alias sub-directory is created, it has the same content + as the default variant index, but the links are adjusted accordingly. + + Index directory structure: + index_base_dir/ (hosted at wheels.vllm.ai/{nightly,$commit,$version}/) + index.html # project list, linking to "vllm/" and other packages, and all variant sub-directories + vllm/ + index.html # package index, pointing to actual files in wheel_base_dir (relative path) + metadata.json # machine-readable metadata for all wheels in this package + cpu/ # cpu variant sub-directory + index.html + vllm/ + index.html + metadata.json + cu129/ # cu129 is actually the alias to default variant + index.html + vllm/ + index.html + metadata.json + cu130/ # cu130 variant sub-directory + index.html + vllm/ + index.html + metadata.json + ... + + metadata.json stores a dump of all wheel files' metadata in a machine-readable format: + [ + { + "package_name": "vllm", + "version": "0.10.2rc2", + "build_tag": null, + "python_tag": "cp38", + "abi_tag": "abi3", + "platform_tag": "manylinux2014_aarch64", + "variant": "cu129", + "filename": "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl", + "path": "../vllm-0.10.2rc2%2Bcu129-cp38-abi3-manylinux2014_aarch64.whl" # to be concatenated with the directory URL and URL-encoded + }, + ... + ] + """ + + parsed_files = [parse_from_filename(f) for f in whl_files] + + if not parsed_files: + print("No wheel files found, skipping index generation.") + return + + # Group by variant + variant_to_files: dict[str, list[WheelFileInfo]] = {} + for file in parsed_files: + variant = file.variant or "default" + if variant not in variant_to_files: + variant_to_files[variant] = [] + variant_to_files[variant].append(file) + + print(f"Found variants: {list(variant_to_files.keys())}") + + # sanity check for default variant + if default_variant: + if "default" in variant_to_files: + raise ValueError( + "All wheel files must have variant suffixes when `default_variant` is specified." 
+ ) + if default_variant not in variant_to_files: + raise ValueError( + f"Default variant '{default_variant}' not found among wheel files." + ) + + if alias_to_default: + if "default" not in variant_to_files: + # e.g. only some wheels are uploaded to S3 currently + print( + "[WARN] Alias to default variant specified, but no default variant found." + ) + elif alias_to_default in variant_to_files: + raise ValueError( + f"Alias variant name '{alias_to_default}' already exists among wheel files." + ) + else: + variant_to_files[alias_to_default] = variant_to_files["default"].copy() + print(f"Alias variant '{alias_to_default}' created for default variant.") + + # Generate comment in HTML header + comment_str = f" ({comment})" if comment else "" + comment_tmpl = f"Generated on {datetime.now().isoformat()}{comment_str}" + + # Generate index for each variant + subdir_names = set() + for variant, files in variant_to_files.items(): + if variant == "default": + variant_dir = index_base_dir + else: + variant_dir = index_base_dir / variant + subdir_names.add(variant) + + variant_dir.mkdir(parents=True, exist_ok=True) + + # gather all package names in this variant + packages = set(f.package_name for f in files) + if variant == "default": + # these packages should also appear in the "project list" + # generate after all variants are processed + subdir_names = subdir_names.union(packages) + else: + # generate project list for this variant directly + project_list_str = generate_project_list(sorted(packages), comment_tmpl) + with open(variant_dir / "index.html", "w") as f: + f.write(project_list_str) + + for package in packages: + # filter files belonging to this package only + package_files = [f for f in files if f.package_name == package] + package_dir = variant_dir / package + package_dir.mkdir(parents=True, exist_ok=True) + index_str, metadata_str = generate_package_index_and_metadata( + package_files, wheel_base_dir, package_dir, comment + ) + with open(package_dir / "index.html", "w") as f: + f.write(index_str) + with open(package_dir / "metadata.json", "w") as f: + f.write(metadata_str) + + # Generate top-level project list index + project_list_str = generate_project_list(sorted(subdir_names), comment_tmpl) + with open(index_base_dir / "index.html", "w") as f: + f.write(project_list_str) + + +if __name__ == "__main__": + """ + Arguments: + --version : version string for the current build (e.g., commit hash) + --current-objects : path to JSON file containing current S3 objects listing in this version directory + --output-dir : directory to store generated index files + --alias-to-default : (optional) alias variant name for the default variant + --comment : (optional) comment string to include in generated HTML files + """ + + parser = argparse.ArgumentParser( + description="Process nightly build wheel files to generate indices." 
+ ) + parser.add_argument( + "--version", + type=str, + required=True, + help="Version string for the current build (e.g., commit hash)", + ) + parser.add_argument( + "--current-objects", + type=str, + required=True, + help="Path to JSON file containing current S3 objects listing in this version directory", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to store generated index files", + ) + parser.add_argument( + "--alias-to-default", + type=str, + default=None, + help="Alias variant name for the default variant", + ) + parser.add_argument( + "--comment", + type=str, + default="", + help="Optional comment string to include in generated HTML files", + ) + + args = parser.parse_args() + + version = args.version + if "/" in version or "\\" in version: + raise ValueError("Version string must not contain slashes.") + current_objects_path = Path(args.current_objects) + output_dir = Path(args.output_dir) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + # Read current objects JSON + with open(current_objects_path) as f: + current_objects: dict[str, list[dict[str, Any]]] = json.load(f) + + # current_objects looks like from list_objects_v2 S3 API: + """ + "Contents": [ + { + "Key": "e2f56c309d2a28899c68975a7e104502d56deb8f/vllm-0.11.2.dev363+ge2f56c309-cp38-abi3-manylinux1_x86_64.whl", + "LastModified": "2025-11-28T14:00:32+00:00", + "ETag": "\"37a38339c7cdb61ca737021b968075df-52\"", + "ChecksumAlgorithm": [ + "CRC64NVME" + ], + "ChecksumType": "FULL_OBJECT", + "Size": 435649349, + "StorageClass": "STANDARD" + }, + ... + ] + """ + + # Extract wheel file keys + wheel_files = [] + for item in current_objects.get("Contents", []): + key: str = item["Key"] + if key.endswith(".whl"): + wheel_files.append(key.split("/")[-1]) # only the filename is used + + print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}") + + # keep only "official" files for a non-nightly version (specifed by cli args) + PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$") + if PY_VERSION_RE.match(version): + # upload-wheels.sh ensures no "dev" is in args.version + wheel_files = list( + filter(lambda x: version in x and "dev" not in x, wheel_files) + ) + print(f"Non-nightly version detected, wheel files used: {wheel_files}") + else: + print("Nightly version detected, keeping all wheel files.") + + # Generate index and metadata, assuming wheels and indices are stored as: + # s3://vllm-wheels/{version}/ + # s3://vllm-wheels// + wheel_base_dir = Path(output_dir).parent / version + index_base_dir = Path(output_dir) + + generate_index_and_metadata( + whl_files=wheel_files, + wheel_base_dir=wheel_base_dir, + index_base_dir=index_base_dir, + default_variant=None, + alias_to_default=args.alias_to_default, + comment=args.comment.strip(), + ) + print(f"Successfully generated index and metadata in {output_dir}") diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh index d0036f24c8d0..b6274d698d01 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh @@ -7,53 +7,57 @@ set -ex # allow to bind to different cores CORE_RANGE=${CORE_RANGE:-0-16} OMP_CORE_RANGE=${OMP_CORE_RANGE:-0-16} -NUMA_NODE=${NUMA_NODE:-0} -export CMAKE_BUILD_PARALLEL_LEVEL=32 +export CMAKE_BUILD_PARALLEL_LEVEL=16 # Setup cleanup remove_docker_container() { set -e; - docker rm -f cpu-test-"$NUMA_NODE" || true; + docker rm -f cpu-test 
|| true; } trap remove_docker_container EXIT remove_docker_container # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . +docker build --tag cpu-test --target vllm-test -f docker/Dockerfile.cpu . -# Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +# Run the image +docker run -itd --cpuset-cpus="$CORE_RANGE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test cpu-test function cpu_tests() { set -e - export NUMA_NODE=$2 - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e pip list" # offline inference - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run model tests + docker exec cpu-test bash -c " + set -e + pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model" + # Run kernel tests - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e pytest -x -v -s tests/kernels/test_onednn.py - pytest -x -v -s tests/kernels/attention/test_cpu_attn.py" + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py + pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic" # basic online serving - docker exec cpu-test-"$NUMA_NODE" bash -c ' + docker exec cpu-test bash -c ' set -e - VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve meta-llama/Llama-3.2-3B-Instruct --max-model-len 2048 & + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 & server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ - --model meta-llama/Llama-3.2-3B-Instruct \ + --model Qwen/Qwen3-0.6B \ --num-prompts 20 \ --endpoint /v1/completions kill -s SIGTERM $server_pid &' @@ -61,4 +65,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 2267718f75ca..438fe522c870 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -21,8 +21,8 @@ trap remove_docker_container EXIT remove_docker_container # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . 
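The x86 CPU test script below keeps numactl and the --cpuset-cpus/--cpuset-mems pinning that the arm64 script above drops. As an illustrative aside only (not part of the CI scripts), a CORE_RANGE for a given NUMA node could be derived from sysfs on a Linux host like this:

# Illustrative: derive a core range for NUMA node 0, the same information the
# --cpuset-cpus/--cpuset-mems docker flags in these scripts consume.
import os
from pathlib import Path

def core_range_for_numa_node(node: int = 0) -> str:
    cpulist = Path(f"/sys/devices/system/node/node{node}/cpulist")
    if cpulist.exists():
        # e.g. "0-31" or "0-15,32-47"
        return cpulist.read_text().strip()
    # fall back to all CPUs visible to this process if sysfs is absent
    cpus = sorted(os.sched_getaffinity(0))
    return f"{cpus[0]}-{cpus[-1]}"

print(core_range_for_numa_node(0))  # value to export as CORE_RANGE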
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-npu-test.sh b/.buildkite/scripts/hardware_ci/run-npu-test.sh index 29c8f5ed5a91..0db1abe37ba1 100644 --- a/.buildkite/scripts/hardware_ci/run-npu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-npu-test.sh @@ -74,6 +74,7 @@ FROM ${BASE_IMAGE_NAME} # Define environments ENV DEBIAN_FRONTEND=noninteractive +ENV SOC_VERSION="ascend910b1" RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \ pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \ diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index d49f3e2f47cf..dfc9db512d1e 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -35,9 +35,10 @@ docker run \ echo $ZE_AFFINITY_MASK pip install tblib==3.1.0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager - python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager cd tests pytest -v -s v1/core @@ -46,6 +47,6 @@ docker run \ pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py pytest -v -s v1/test_serial_utils.py ' diff --git a/.buildkite/scripts/run-prime-rl-test.sh b/.buildkite/scripts/run-prime-rl-test.sh index 5b25c358fc4a..3fb7c82c8d33 100755 --- a/.buildkite/scripts/run-prime-rl-test.sh +++ 
b/.buildkite/scripts/run-prime-rl-test.sh @@ -12,6 +12,11 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git" PRIME_RL_DIR="${REPO_ROOT}/prime-rl" +if command -v rocm-smi &> /dev/null || command -v rocminfo &> /dev/null; then + echo "AMD GPU detected. Prime-RL currently only supports NVIDIA. Skipping..." + exit 0 +fi + echo "Setting up Prime-RL integration test environment..." # Clean up any existing Prime-RL directory diff --git a/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh new file mode 100644 index 000000000000..937a43d1a322 --- /dev/null +++ b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -euxo pipefail + +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +THRESHOLD=${1:-0.25} +NUM_Q=${2:-1319} +PORT=${3:-8040} +OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} +mkdir -p "${OUT_DIR}" + +wait_for_server() { + local port=$1 + timeout 600 bash -c ' + until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do + sleep 1 + done' +} + +MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct" + +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then + kill "${SERVER_PID}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${SERVER_PID}" 2>/dev/null || break + sleep 0.5 + done + kill -9 "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +for BACK in "${BACKENDS[@]}"; do + VLLM_DEEP_GEMM_WARMUP=skip \ + VLLM_ALL2ALL_BACKEND=$BACK \ + vllm serve "$MODEL" \ + --enforce-eager \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \ + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \ + --trust-remote-code \ + --max-model-len 2048 \ + --gpu-memory-utilization 0.9 \ + --port $PORT & + SERVER_PID=$! 
+ wait_for_server $PORT + + TAG=$(echo "$MODEL" | tr '/: \\n' '_____') + OUT="${OUT_DIR}/${TAG}_${BACK}.json" + python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT} + python3 - <<PY +import json +acc = json.load(open("${OUT}"))["accuracy"] +assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}" +PY + + cleanup + SERVER_PID= + sleep 1 + PORT=$((PORT+1)) +done diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 945c5e48c009..3a218a4bb2e6 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -2,6 +2,28 @@ set -ex +# ======== part 0: setup ======== + +BUCKET="vllm-wheels" +INDICES_OUTPUT_DIR="indices" +DEFAULT_VARIANT_ALIAS="cu129" # align with vLLM_MAIN_CUDA_VERSION in vllm/envs.py +PYTHON=${PYTHON_PROG:=python3} # try to read from env var, otherwise use python3 +SUBPATH=$BUILDKITE_COMMIT +S3_COMMIT_PREFIX="s3://$BUCKET/$SUBPATH/" + +# detect if python3.12+ is available +has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12) else 0)") +if [[ "$has_new_python" -eq 0 ]]; then + # use new python from docker + docker pull python:3-slim + PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3" +fi + +echo "Using python interpreter: $PYTHON" +echo "Python version: $($PYTHON --version)" + +# ========= part 1: collect, rename & upload the wheel ========== + # Assume wheels are in artifacts/dist/*.whl wheel_files=(artifacts/dist/*.whl) @@ -10,74 +32,76 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}" exit 1 fi - -# Get the single wheel file wheel="${wheel_files[0]}" -# Detect architecture and rename 'linux' to appropriate manylinux version -arch=$(uname -m) -if [[ $arch == "x86_64" ]]; then - manylinux_version="manylinux1" -elif [[ $arch == "aarch64" ]]; then - manylinux_version="manylinux2014" -else - echo "Warning: Unknown architecture $arch, using manylinux1 as default" - manylinux_version="manylinux1" -fi +# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31 +# we also accept params as manylinux tag +# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels +manylinux_version="${1:-manylinux_2_31}" # Rename 'linux' to the appropriate manylinux version in the wheel filename +if [[ "$wheel" != *"linux"* ]]; then + echo "Error: Wheel filename does not contain 'linux': $wheel" + exit 1 +fi new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" +echo "Renamed wheel to: $wheel" # Extract the version from the wheel version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) -echo "Version: $version" - -normal_wheel="$wheel" # Save the original wheel filename - -# If the version contains "dev", rename it to v1.0.0.dev for consistency -if [[ $version == *dev* ]]; then - suffix="${version##*.}" - if [[ $suffix == cu* ]]; then - new_version="1.0.0.dev+${suffix}" - else - new_version="1.0.0.dev" - fi - new_wheel="${wheel/$version/$new_version}" - # use cp to keep both files in the artifacts directory - cp -- "$wheel" "$new_wheel" - wheel="$new_wheel" - version="$new_version" -fi +echo "Version in wheel: $version" +pure_version="${version%%+*}" +echo "Pure version (without variant): $pure_version"
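Part 2 of upload-wheels.sh below publishes the indices and metadata.json emitted by generate-nightly-index.py next to the wheels. A downstream job could therefore resolve a wheel URL from metadata.json instead of scraping index.html; here is a hypothetical consumer sketch, assuming the wheels.vllm.ai hosting layout described in the generator's docstring (the commit hash is the example from that docstring, and the selection criteria are illustrative).

# Hypothetical consumer: pick an x86_64 cu130 wheel for a commit from metadata.json.
import json
from urllib.parse import urljoin
from urllib.request import urlopen

commit = "e2f56c309d2a28899c68975a7e104502d56deb8f"  # example commit
index_url = f"https://wheels.vllm.ai/{commit}/cu130/vllm/"

with urlopen(urljoin(index_url, "metadata.json")) as resp:
    entries = json.load(resp)

wheel = next(
    e for e in entries
    if e["platform_tag"].endswith("x86_64") and e["variant"] == "cu130"
)
# "path" is relative to the index directory and already URL-encoded
print(urljoin(index_url, wheel["path"]))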
-# Upload the wheel to S3 -python3 .buildkite/generate_index.py --wheel "$normal_wheel" -# generate index for this commit -aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" +# ========= part 2: generate and upload indices ========== +# generate indices for all existing wheels in the commit directory +# this script might be run multiple times if there are multiple variants being built +# so we need to guarantee there is little chance for "TOCTOU" issues +# i.e., one process generating indices while another is uploading a new wheel +# hence we need to ensure no time-consuming operations happen below -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" - aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +# list all wheels in the commit directory +echo "Existing wheels on S3:" +aws s3 ls "$S3_COMMIT_PREFIX" +obj_json="objects.json" +aws s3api list-objects-v2 --bucket "$BUCKET" --prefix "$SUBPATH/" --delimiter / --output json > "$obj_json" +mkdir -p "$INDICES_OUTPUT_DIR" + +# call script to generate indices for all existing wheels +# these indices have relative paths that work as long as they sit next to the wheel directory in s3 +# i.e., the wheels are always in s3://vllm-wheels/<commit>/ +# and indices can be placed in /<commit>/, or /nightly/, or /<version>/ +if [[ ! -z "$DEFAULT_VARIANT_ALIAS" ]]; then + alias_arg="--alias-to-default $DEFAULT_VARIANT_ALIAS" else - echo "Skipping index files for non-cu129 wheels" + alias_arg="" fi -# generate index for nightly -aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" +# HACK: we do not need the regex module here, but it is required by the pre-commit hook +# To avoid any external dependency, we simply switch it back to the stdlib re module +sed -i 's/import regex as re/import re/g' .buildkite/scripts/generate-nightly-index.py +$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "commit $BUILDKITE_COMMIT" $alias_arg -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" -else - echo "Skipping index files for non-cu129 wheels" +# copy indices to /<commit>/ unconditionally +echo "Uploading indices to $S3_COMMIT_PREFIX" +aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "$S3_COMMIT_PREFIX" + +# copy to /nightly/ only if it is on the main branch and not a PR +if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]]; then + echo "Uploading indices to overwrite /nightly/" + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/" fi -aws s3 cp "$wheel" "s3://vllm-wheels/$version/" -aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" +# re-generate and copy to /<version>/ only if it does not have "dev" in the version +if [[ "$version" != *"dev"* ]]; then + echo "Re-generating indices for /$pure_version/" + rm -rf "$INDICES_OUTPUT_DIR/*" + mkdir -p "$INDICES_OUTPUT_DIR" + $PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/" +fi diff --git
a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4ddf11c0b268..3c9b8cbedcf0 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -39,9 +39,9 @@ steps: # if this test fails, it means the nightly torch version is not compatible with some # of the dependencies. Please check the error message and add the package to whitelist # in /vllm/tools/pre_commit/generate_nightly_torch_test.py - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking soft_fail: true source_file_dependencies: - requirements/nightly_torch_test.txt @@ -50,9 +50,9 @@ steps: - label: Async Engine, Inputs, Utils, Worker Test # 10min timeout_in_minutes: 15 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/multimodal @@ -61,17 +61,19 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 - mirror_hardwares: [amdexperimental, amdproduction] +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -80,6 +82,8 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -113,9 +117,9 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - label: Entrypoints Unit Tests # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 10 working_dir: "/vllm-workspace/tests" fast_check: true @@ -212,6 +216,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py @@ -250,9 +255,9 @@ steps: - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep - label: EPLB Algorithm Test # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: @@ -308,28 +313,25 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: 
Blocking source_file_dependencies: - vllm/ - tests/engine - - tests/tokenization - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenization -- label: V1 Test e2e + engine # 30min - timeout_in_minutes: 45 +- label: V1 Test e2e + engine # 65min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] - agent_pool: mi325_1 + agent_pool: mi325_4 # grade: Blocking source_file_dependencies: - vllm/ @@ -342,9 +344,9 @@ steps: - label: V1 Test entrypoints # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -392,6 +394,21 @@ steps: commands: - pytest -v -s v1/attention +- label: Batch Invariance Tests (H100) # 10min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + timeout_in_minutes: 25 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - vllm/model_executor/layers + - tests/v1/determinism/ + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pip install pytest-timeout pytest-forked + - pytest -v -s v1/determinism/test_batch_invariance.py + - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py + - label: V1 Test attention (B200) # 10min timeout_in_minutes: 30 gpu: b200 @@ -402,9 +419,9 @@ steps: - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - label: V1 Test others (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -420,29 +437,34 @@ steps: - label: Examples Test # 30min timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints + - vllm/multimodal - examples/ commands: - pip install tensorizer # for tensorizer test + # for basic + - python3 offline_inference/basic/chat.py - python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - - python3 offline_inference/basic/chat.py - - python3 offline_inference/prefix_caching.py - - python3 offline_inference/llm_engine_example.py + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + # for multi-modal models - python3 offline_inference/audio_language.py --seed 0 - python3 offline_inference/vision_language.py --seed 0 - - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - - python3 offline_inference/basic/classify.py - - python3 
offline_inference/basic/embed.py - - python3 offline_inference/basic/score.py + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 @@ -496,7 +518,7 @@ steps: - label: PyTorch Compilation Unit Tests # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -513,7 +535,7 @@ steps: - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -569,7 +591,7 @@ steps: - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -596,7 +618,7 @@ steps: - label: Kernels MoE Test %N # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -623,6 +645,26 @@ steps: commands: - pytest -v -s kernels/mamba +- label: Kernels DeepGEMM Test (H100) # Nvidia-centric +# Not replicating for CUTLAS & CuTe + timeout_in_minutes: 45 + gpu: h100 + num_gpus: 1 + source_file_dependencies: + - tools/install_deepgemm.sh + - vllm/utils/deep_gemm.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization + - tests/kernels/quantization/test_block_fp8.py + - tests/kernels/moe/test_deepgemm.py + - tests/kernels/moe/test_batched_deepgemm.py + - tests/kernels/attention/test_deepgemm_attention.py + commands: + - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/moe/test_deepgemm.py + - pytest -v -s kernels/moe/test_batched_deepgemm.py + - pytest -v -s kernels/attention/test_deepgemm_attention.py + - label: Model Executor Test # 23min timeout_in_minutes: 35 torch_nightly: true @@ -681,16 +723,18 @@ steps: # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - uv pip install --system torchao==0.13.0 + - uv pip install --system conch-triton-kernels - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py -- label: LM Eval Small Models # 15min - timeout_in_minutes: 20 - mirror_hardwares: [amdexperimental, amdproduction] +- label: LM Eval Small Models # 53min + timeout_in_minutes: 75 + 
mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization + autorun_on_main: true commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 @@ -703,7 +747,7 @@ steps: - csrc/ - vllm/entrypoints/openai/ - vllm/model_executor/models/whisper.py - commands: # LMEval + commands: # LMEval+Transcription WER check # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 - pytest -s entrypoints/openai/correctness/ @@ -717,19 +761,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_1 - # grade: Blocking - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use ##### models test ##### @@ -900,6 +932,18 @@ steps: commands: - pytest -v -s models/language/pooling_mteb_test +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Processor Test # 44min timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] @@ -927,8 +971,8 @@ steps: - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. 
&& VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work -- label: Multi-Modal Accuracy Eval (Small Models) # 10min - timeout_in_minutes: 70 +- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min + timeout_in_minutes: 180 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -940,7 +984,8 @@ steps: commands: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 -- label: Multi-Modal Models Test (Extended) 1 +- label: Multi-Modal Models Test (Extended) 1 # 60min + timeout_in_minutes: 120 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -964,7 +1009,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' -- label: Multi-Modal Models Test (Extended) 3 +- label: Multi-Modal Models Test (Extended) 3 # 75min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -1056,6 +1102,7 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py - label: Blackwell Fusion and Compile Tests # 30 min timeout_in_minutes: 40 @@ -1065,11 +1112,18 @@ steps: - csrc/quantization/fp4/ - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/worker/ + - vllm/v1/cudagraph_dispatcher.py - vllm/compilation/ # can affect pattern matching - vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusion_attn.py + - tests/compile/test_silu_mul_quant_fusion.py + - tests/compile/distributed/test_fusion_all_reduce.py + - tests/compile/distributed/test_fusions_e2e.py + - tests/compile/fullgraph/test_full_graph.py commands: - nvidia-smi - pytest -v -s tests/compile/test_fusion_attn.py @@ -1080,7 +1134,7 @@ steps: # Wrap with quotes to escape yaml - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile + - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell Fusion E2E Tests # 30 min timeout_in_minutes: 40 @@ -1098,17 +1152,15 @@ steps: - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - tests/compile/distributed/test_fusions_e2e.py - - tests/compile/fullgraph/test_full_graph.py commands: - nvidia-smi # Run all e2e fusion tests - - pytest -v -s tests/compile/test_fusions_e2e.py + - pytest -v -s tests/compile/distributed/test_fusions_e2e.py -- label: ROCm GPT-OSS Eval +- label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 working_dir: "/vllm-workspace/" - agent_pool: mi325_1 - mirror_hardwares: [amdexperimental, amdproduction] + gpu: b200 optional: true # run on nightlies source_file_dependencies: - tests/evals/gpt_oss @@ -1117,7 +1169,7 @@ steps: - vllm/v1/attention/backends/flashinfer.py commands: - uv pip install --system 'gpt-oss[eval]==0.0.5' 
- - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 - label: Blackwell Quantized MoE Test timeout_in_minutes: 60 @@ -1217,6 +1269,7 @@ steps: - tests/v1/worker/test_worker_memory_snapshot.py commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py @@ -1252,7 +1305,7 @@ steps: - label: Plugin Tests (2 GPUs) # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1321,14 +1374,14 @@ steps: - pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_olmoe_tp.py - # Disabled for now because MXFP4 backend on non-cuda platform + # Disabled for now because MXFP4 backend on non-cuda platform # doesn't support LoRA yet #- pytest -v -s -x lora/test_gptoss_tp.py - label: Weight Loading Multiple GPU Test # 33min timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1387,12 +1440,13 @@ steps: - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py + - label: LM Eval Large Models # optional - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_4 - # grade: Blocking gpu: a100 optional: true + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: @@ -1404,11 +1458,11 @@ steps: ##### H100 test ##### - label: LM Eval Large Models (H100) # optional - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_4 - # grade: Blocking gpu: h100 optional: true + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: @@ -1418,6 +1472,7 @@ steps: - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + ##### H200 test ##### - label: Distributed Tests (H200) # optional mirror_hardwares: [amdexperimental] @@ -1428,14 +1483,14 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/distributed/test_async_tp.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - - pytest -v -s tests/compile/distributed/test_sequence_parallel.py + - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" + - 
VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - pytest -v -s tests/v1/distributed/test_dbo.py ##### B200 test ##### @@ -1449,6 +1504,57 @@ steps: - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py - pytest -v -s tests/v1/distributed/test_dbo.py +##### E2E Eval Tests ##### +- label: LM Eval Small Models (1 Card) # 15min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + +- label: LM Eval Large Models (4 Card) + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + gpu: a100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: ROCm LM Eval Large Models (8 Card) + mirror_hardwares: [amdproduction] + agent_pool: mi325_8 + num_gpus: 8 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-rocm.txt --tp-size=8 + +- label: ROCm GPT-OSS Eval + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + agent_pool: mi325_1 + mirror_hardwares: [amdexperimental, amdproduction] + optional: true # run on nightlies + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min mirror_hardwares: [amdexperimental] @@ -1463,9 +1569,8 @@ steps: - .buildkite/scripts/run-prime-rl-test.sh commands: - bash .buildkite/scripts/run-prime-rl-test.sh - - label: DeepSeek V2-Lite Accuracy - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1476,8 +1581,8 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 -- label: Qwen3-30B-A3B-FP8-block Accuracy - mirror_hardwares: [amdexperimental] +- label: Qwen3-30B-A3B-FP8-block Accuracy (H100) + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1487,3 +1592,35 @@ 
steps: working_dir: "/vllm-workspace" commands: - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 + +- label: DeepSeek V2-Lite Async EPLB Accuracy + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030 + +- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e444becd9867..2dcca5711b3d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,14 +57,16 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 source_file_dependencies: - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -73,6 +75,8 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -276,21 +280,18 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/engine - - tests/tokenization - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenization - label: V1 Test e2e + engine # 30min timeout_in_minutes: 45 @@ -351,7 +352,8 @@ steps: timeout_in_minutes: 25 gpu: h100 source_file_dependencies: - - vllm/ + - vllm/v1/attention + - vllm/model_executor/layers - tests/v1/determinism/ commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn @@ -388,23 +390,28 @@ steps: working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints + - vllm/multimodal - examples/ commands: - pip install tensorizer # for tensorizer test + # for basic + - python3 offline_inference/basic/chat.py - python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - - python3 offline_inference/basic/chat.py - - python3 offline_inference/prefix_caching.py - - python3 offline_inference/llm_engine_example.py + - python3 
offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + # for multi-modal models - python3 offline_inference/audio_language.py --seed 0 - python3 offline_inference/vision_language.py --seed 0 - - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - - python3 offline_inference/basic/classify.py - - python3 offline_inference/basic/embed.py - - python3 offline_inference/basic/score.py + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 @@ -463,7 +470,9 @@ steps: # tests covered elsewhere. # Use `find` to launch multiple instances of pytest so that # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" + # However, find does not normally propagate error codes, so we combine it with xargs + # (using -0 for proper path handling) + - "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'" - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 @@ -477,7 +486,9 @@ steps: # as it is a heavy test that is covered in other steps. 
# Use `find` to launch multiple instances of pytest so that # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;" + # However, find does not normally propagate error codes, so we combine it with xargs + # (using -0 for proper path handling) + - "find compile/fullgraph -maxdepth 1 -name 'test_*.py' -not -name 'test_full_graph.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'" - label: PyTorch Fullgraph Test # 27min timeout_in_minutes: 40 @@ -632,6 +643,7 @@ steps: # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 + - uv pip install --system conch-triton-kernels - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - label: LM Eval Small Models # 53min @@ -662,16 +674,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use ##### models test ##### @@ -682,6 +685,7 @@ steps: source_file_dependencies: - vllm/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Run a subset of model initialization tests - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset @@ -694,6 +698,7 @@ steps: - vllm/model_executor/models/ - vllm/transformers_utils/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Only when vLLM model source is modified - test initialization of a large # subset of supported models (the complement of the small subset in the above @@ -819,14 +824,24 @@ steps: commands: - pytest -v -s models/language/pooling_mteb_test -- label: Multi-Modal Processor Test # 44min +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'" + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Processor Test timeout_in_minutes: 60 source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/multimodal/processing + - pytest -v -s models/multimodal/processing/test_tensor_schema.py - label: Multi-Modal Models Test (Standard) # 60min timeout_in_minutes: 80 @@ -903,11 +918,12 @@ steps: - label: Transformers Nightly Models Test working_dir: "/vllm-workspace/" optional: true + soft_fail: true commands: - pip install --upgrade git+https://github.com/huggingface/transformers - - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)' + - pytest -v -s tests/models/test_initialization.py - pytest -v -s tests/models/test_transformers.py - # - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/processing/ - pytest -v -s tests/models/multimodal/test_mapping.py - python3 
examples/offline_inference/basic/chat.py - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl @@ -971,7 +987,6 @@ steps: - vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - - vllm/model_executor/layers/fused_moe/layer.py - tests/compile/test_fusion_attn.py - tests/compile/test_silu_mul_quant_fusion.py - tests/compile/distributed/test_fusion_all_reduce.py @@ -1302,11 +1317,11 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/distributed/test_async_tp.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py - - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - - pytest -v -s tests/distributed/test_sequence_parallel.py + - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - pytest -v -s tests/v1/distributed/test_dbo.py @@ -1326,6 +1341,7 @@ steps: - label: Prime-RL Integration Test # 15min timeout_in_minutes: 30 optional: true + soft_fail: true num_gpus: 2 working_dir: "/vllm-workspace" source_file_dependencies: @@ -1359,4 +1375,4 @@ steps: num_gpus: 2 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 diff --git a/.buildkite/test_areas/attention.yaml b/.buildkite/test_areas/attention.yaml new file mode 100644 index 000000000000..6e444eae14c7 --- /dev/null +++ b/.buildkite/test_areas/attention.yaml @@ -0,0 +1,21 @@ +group: Attention +depends_on: + - image-build +steps: +- label: V1 attention (H100) + timeout_in_minutes: 30 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - tests/v1/attention + commands: + - pytest -v -s v1/attention + +- label: V1 attention (B200) + timeout_in_minutes: 30 + gpu: b200 + source_file_dependencies: + - vllm/v1/attention + - tests/v1/attention + commands: + - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this diff --git a/.buildkite/test_areas/basic_correctness.yaml b/.buildkite/test_areas/basic_correctness.yaml new file mode 100644 index 000000000000..759d2b535871 --- /dev/null +++ b/.buildkite/test_areas/basic_correctness.yaml @@ -0,0 +1,16 @@ +group: Basic Correctness +depends_on: + - image-build +steps: +- label: Basic Correctness + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_cumem.py + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s basic_correctness/test_cumem.py + - pytest -v -s basic_correctness/test_basic_correctness.py + - pytest -v -s basic_correctness/test_cpu_offload.py 
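The switch from `find ... -exec pytest -s -v {} \;` to `find ... -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'` in the compile-test hunks above does what the new inline comments say: with the `-exec ... \;` form, find's own exit status ignores a failing pytest invocation, so a red test run could leave the CI step green, while xargs exits with 123 whenever any invocation fails, and `-print0`/`-0` keeps unusual file names intact. A small shell sketch of the difference; the /tmp/find_demo directory and its test file are made up for illustration and assume pytest is installed:

# Illustration only, not part of the patch.
mkdir -p /tmp/find_demo
printf 'def test_always_fails():\n    assert False\n' > /tmp/find_demo/test_demo.py

find /tmp/find_demo -name 'test_*.py' -exec pytest -q {} \;
echo $?   # 0: find ignores the non-zero exit code from pytest, so the step would stay green

find /tmp/find_demo -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -q '{}'
echo $?   # 123: xargs reports that at least one pytest invocation failed, so the step goes red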
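The new .buildkite/test_areas/*.yaml files, starting with attention.yaml and basic_correctness.yaml above and continuing below, all share one shape: a group name, a depends_on entry pointing at the image build, and a steps list that reuses the per-step keys already used in test-pipeline.yaml. A minimal sketch of a hypothetical area file, restricted to keys that appear in this diff; how each key is interpreted is decided by the CI tooling that consumes these files, which is not shown here:

group: Example Area
depends_on:
  - image-build                 # run only after the CI image is available
steps:
- label: Example GPU tests
  timeout_in_minutes: 30
  source_file_dependencies:     # paths that gate when this step runs
    - vllm/example/
    - tests/example/
  commands:
    - pytest -v -s example/
- label: Example CPU-only tests
  timeout_in_minutes: 10
  no_gpu: true                  # schedule on a CPU agent, as in the (CPU) steps above
  source_file_dependencies:
    - tests/example/
  commands:
    - pytest -v -s -m 'cpu_test' example/

Hardware-specific steps in these files layer extra keys on top of this shape, for example gpu:, num_gpus:, optional:, and working_dir:, all of which also appear verbatim in this diff.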
diff --git a/.buildkite/test_areas/benchmarks.yaml b/.buildkite/test_areas/benchmarks.yaml new file mode 100644 index 000000000000..574b642d407b --- /dev/null +++ b/.buildkite/test_areas/benchmarks.yaml @@ -0,0 +1,19 @@ +group: Benchmarks +depends_on: + - image-build +steps: +- label: Benchmarks + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/.buildkite" + source_file_dependencies: + - benchmarks/ + commands: + - bash scripts/run-benchmarks.sh + +- label: Benchmarks CLI Test + timeout_in_minutes: 20 + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml new file mode 100644 index 000000000000..0ba00925a483 --- /dev/null +++ b/.buildkite/test_areas/compile.yaml @@ -0,0 +1,57 @@ +group: Compile +depends_on: + - image-build +steps: +- label: Fusion and Compile Tests (B200) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/worker/ + - vllm/v1/cudagraph_dispatcher.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusion_attn.py + - tests/compile/test_silu_mul_quant_fusion.py + - tests/compile/distributed/test_fusion_all_reduce.py + - tests/compile/distributed/test_fusions_e2e.py + - tests/compile/fullgraph/test_full_graph.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py + # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time + # Wrap with quotes to escape yaml + - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile + +- label: Fusion E2E (2 GPUs)(B200) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true + num_gpus: 2 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/distributed/test_fusions_e2e.py + commands: + - nvidia-smi + # Run all e2e fusion tests + - pytest -v -s tests/compile/distributed/test_fusions_e2e.py + diff --git a/.buildkite/test_areas/cuda.yaml b/.buildkite/test_areas/cuda.yaml new file mode 100644 index 000000000000..50c0c338c243 --- /dev/null +++ b/.buildkite/test_areas/cuda.yaml @@ -0,0 +1,22 @@ +group: CUDA +depends_on: + - image-build +steps: +- label: Platform Tests (CUDA) + timeout_in_minutes: 15 + source_file_dependencies: + - vllm/ + - tests/cuda + commands: + - pytest -v -s cuda/test_cuda_context.py + +- label: Cudagraph + timeout_in_minutes: 20 
+ source_file_dependencies: + - tests/v1/cudagraph + - vllm/v1/cudagraph_dispatcher.py + - vllm/config/compilation.py + - vllm/compilation + commands: + - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py + - pytest -v -s v1/cudagraph/test_cudagraph_mode.py \ No newline at end of file diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml new file mode 100644 index 000000000000..2cc90698d916 --- /dev/null +++ b/.buildkite/test_areas/distributed.yaml @@ -0,0 +1,199 @@ +group: Distributed +depends_on: + - image-build +steps: +- label: Distributed Comm Ops + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py + +- label: Distributed (2 GPUs) + timeout_in_minutes: 90 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/compilation/ + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/worker/worker_base.py + - vllm/v1/engine/ + - vllm/v1/worker/ + - tests/compile/fullgraph/test_basic_correctness.py + - tests/compile/test_wrapper.py + - tests/distributed/ + - tests/entrypoints/llm/test_collective_rpc.py + - tests/v1/distributed + - tests/v1/entrypoints/openai/test_multi_api_servers.py + - tests/v1/shutdown + - tests/v1/worker/test_worker_memory_snapshot.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py + - pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/fullgraph/test_basic_correctness.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - pytest -v -s distributed/test_sequence_parallel.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown + - pytest -v -s v1/worker/test_worker_memory_snapshot.py + +- label: Distributed Tests (4 GPUs) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - tests/distributed/test_utils + - tests/distributed/test_pynccl + - tests/distributed/test_events + - tests/compile/fullgraph/test_basic_correctness.py + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py + - tests/examples/offline_inference/data_parallel.py + - tests/v1/distributed + - tests/v1/engine/test_engine_core_client.py + - tests/distributed/test_symm_mem_allreduce.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + # test with torchrun tp=2 and external_dp=2 + - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=4 and dp=1 + - TP_SIZE=4 torchrun --nproc-per-node=4 
distributed/test_torchrun_example_moe.py + # test with torchrun tp=2, pp=2 and dp=1 + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=1 and dp=4 with ep + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2 and dp=2 with ep + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with internal dp + - python3 ../examples/offline_inference/data_parallel.py --enforce-eager + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp + - pytest -v -s distributed/test_utils.py + - pytest -v -s compile/fullgraph/test_basic_correctness.py + - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py + - pytest -v -s distributed/test_symm_mem_allreduce.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - cd ../examples/offline_inference + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + +- label: Distributed Tests (8 GPUs)(H100) + timeout_in_minutes: 10 + gpu: h100 + num_gpus: 8 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - examples/offline_inference/torchrun_dp_example.py + - vllm/config/parallel.py + - vllm/distributed/ + - vllm/v1/engine/llm_engine.py + - vllm/v1/executor/uniproc_executor.py + - vllm/v1/worker/gpu_worker.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + # test with torchrun tp=2 and dp=4 with ep + - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep + +- label: Distributed Tests (4 GPUs)(A100) + gpu: a100 + optional: true + num_gpus: 4 + source_file_dependencies: + - vllm/ + commands: + # NOTE: don't test llama model here, it seems hf implementation is buggy + # see https://github.com/vllm-project/vllm/pull/5689 for details + - pytest -v -s distributed/test_custom_all_reduce.py + - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - pytest -v -s -x lora/test_mixtral.py + +- label: Distributed Tests (2 GPUs)(H200) + gpu: h200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py + - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py + - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4' + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py + - pytest -v -s tests/distributed/test_context_parallel.py + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model 
Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/v1/distributed/test_dbo.py + +- label: Distributed Tests (2 GPUs)(B200) + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py + +- label: 2 Node Test (4 GPUs) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + - tests/examples/offline_inference/data_parallel.py + commands: + - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:0bec63fa317e1fbd62e19b0fc31c43c81bf89077 "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code" + +- label: Distributed NixlConnector PD accuracy (4 GPUs) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh + +- label: Pipeline + Context Parallelism (4 GPUs)) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_pp_cudagraph.py + - pytest -v -s distributed/test_pipeline_parallel.py \ No newline at end of file diff --git a/.buildkite/test_areas/e2e_integration.yaml b/.buildkite/test_areas/e2e_integration.yaml new file mode 100644 index 000000000000..93d389815eda --- /dev/null +++ b/.buildkite/test_areas/e2e_integration.yaml @@ -0,0 +1,59 @@ +group: E2E Integration +depends_on: + - image-build +steps: +- label: DeepSeek V2-Lite Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 + +- label: 
Qwen3-30B-A3B-FP8-block Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 + +- label: Prime-RL Integration (2 GPUs) + timeout_in_minutes: 30 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + source_file_dependencies: + - vllm/ + - .buildkite/scripts/run-prime-rl-test.sh + commands: + - bash .buildkite/scripts/run-prime-rl-test.sh + +- label: DeepSeek V2-Lite Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030 + +- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/.buildkite/test_areas/engine.yaml b/.buildkite/test_areas/engine.yaml new file mode 100644 index 000000000000..a028e0e4af4c --- /dev/null +++ b/.buildkite/test_areas/engine.yaml @@ -0,0 +1,26 @@ +group: Engine +depends_on: + - image-build +steps: +- label: Engine + timeout_in_minutes: 15 + source_file_dependencies: + - vllm/ + - tests/engine + - tests/test_sequence + - tests/test_config + - tests/test_logger + - tests/test_vllm_port + commands: + - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py + +- label: V1 e2e + engine + timeout_in_minutes: 45 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
+ - pytest -v -s v1/e2e + - pytest -v -s v1/engine diff --git a/.buildkite/test_areas/entrypoints.yaml b/.buildkite/test_areas/entrypoints.yaml new file mode 100644 index 000000000000..0a789be943f3 --- /dev/null +++ b/.buildkite/test_areas/entrypoints.yaml @@ -0,0 +1,68 @@ +group: Entrypoints +depends_on: + - image-build +steps: +- label: Entrypoints Unit Tests + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/entrypoints + - tests/entrypoints/ + commands: + - pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling + +- label: Entrypoints Integration (LLM) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/llm + - tests/entrypoints/offline_mode + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + +- label: Entrypoints Integration (API Server) + timeout_in_minutes: 130 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/openai + - tests/entrypoints/test_chat_utils + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ + - pytest -v -s entrypoints/test_chat_utils.py + + +- label: Entrypoints Integration (Pooling) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + + +- label: Entrypoints V1 + timeout_in_minutes: 50 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1/entrypoints + +- label: OpenAI API Correctness + timeout_in_minutes: 30 + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml new file mode 100644 index 000000000000..feb8252148c7 --- /dev/null +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -0,0 +1,23 @@ +group: Expert Parallelism +depends_on: + - image-build +steps: +- label: EPLB Algorithm + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + 
commands: + - pytest -v -s distributed/test_eplb_execute.py + - pytest -v -s distributed/test_eplb_spec_decode.py \ No newline at end of file diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml new file mode 100644 index 000000000000..7ca099516d64 --- /dev/null +++ b/.buildkite/test_areas/kernels.yaml @@ -0,0 +1,117 @@ +group: Kernels +depends_on: + - image-build +steps: +- label: Kernels Core Operation Test + timeout_in_minutes: 75 + source_file_dependencies: + - csrc/ + - tests/kernels/core + - tests/kernels/test_top_k_per_row.py + commands: + - pytest -v -s kernels/core kernels/test_top_k_per_row.py + +- label: Kernels Attention Test %N + timeout_in_minutes: 35 + source_file_dependencies: + - csrc/attention/ + - vllm/attention + - vllm/v1/attention + - tests/kernels/attention + commands: + - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Quantization Test %N + timeout_in_minutes: 90 + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/layers/quantization + - tests/kernels/quantization + commands: + - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels MoE Test %N + timeout_in_minutes: 60 + source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ + - csrc/moe/ + - tests/kernels/moe + - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ + - vllm/envs.py + - vllm/config + commands: + - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Mamba Test + timeout_in_minutes: 45 + source_file_dependencies: + - csrc/mamba/ + - tests/kernels/mamba + - vllm/model_executor/layers/mamba/ops + commands: + - pytest -v -s kernels/mamba + +- label: Kernels DeepGEMM Test (H100) + timeout_in_minutes: 45 + gpu: h100 + num_gpus: 1 + source_file_dependencies: + - tools/install_deepgemm.sh + - vllm/utils/deep_gemm.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization + - tests/kernels/quantization/test_block_fp8.py + - tests/kernels/moe/test_deepgemm.py + - tests/kernels/moe/test_batched_deepgemm.py + - tests/kernels/attention/test_deepgemm_attention.py + commands: + - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/moe/test_deepgemm.py + - pytest -v -s kernels/moe/test_batched_deepgemm.py + - pytest -v -s kernels/attention/test_deepgemm_attention.py + +- label: Kernels (B200) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/" + gpu: b200 + # optional: true + source_file_dependencies: + - csrc/quantization/fp4/ + - csrc/attention/mla/ + - csrc/quantization/cutlass_w8a8/moe/ + - vllm/model_executor/layers/fused_moe/cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py + commands: + - nvidia-smi + - python3 examples/offline_inference/basic/chat.py + # Attention + # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s 
tests/kernels/attention/test_attention_selector.py + - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' + - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py + # Quantization + - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' + - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py + - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py \ No newline at end of file diff --git a/.buildkite/test_areas/lm_eval.yaml b/.buildkite/test_areas/lm_eval.yaml new file mode 100644 index 000000000000..9af43e0c375a --- /dev/null +++ b/.buildkite/test_areas/lm_eval.yaml @@ -0,0 +1,46 @@ +group: LM Eval +depends_on: + - image-build +steps: +- label: LM Eval Small Models + timeout_in_minutes: 75 + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + autorun_on_main: true + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + +- label: LM Eval Large Models (4 GPUs)(A100) + gpu: a100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: LM Eval Large Models (4 GPUs)(H100) + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + +- label: LM Eval Small Models (B200) + timeout_in_minutes: 120 + gpu: b200 + optional: true + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 diff --git a/.buildkite/test_areas/lora.yaml b/.buildkite/test_areas/lora.yaml new file mode 100644 index 000000000000..809b4138f44b --- /dev/null +++ b/.buildkite/test_areas/lora.yaml @@ -0,0 +1,31 @@ +group: LoRA +depends_on: + - image-build +steps: +- label: LoRA %N + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + - pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py 
--ignore=lora/test_qwen3moe_tp.py + parallelism: 4 + + +- label: LoRA TP (Distributed) + timeout_in_minutes: 30 + num_gpus: 4 + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. + - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py + - pytest -v -s -x lora/test_olmoe_tp.py + - pytest -v -s -x lora/test_gptoss_tp.py \ No newline at end of file diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml new file mode 100644 index 000000000000..252af1e56a10 --- /dev/null +++ b/.buildkite/test_areas/misc.yaml @@ -0,0 +1,165 @@ +group: Miscellaneous +depends_on: + - image-build +steps: +- label: V1 Others + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + # split the test to avoid interference + - pytest -v -s -m 'not cpu_test' v1/core + - pytest -v -s v1/executor + - pytest -v -s v1/kv_offload + - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors + - pytest -v -s v1/worker + - pytest -v -s v1/spec_decode + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'not cpu_test' v1/metrics + - pytest -v -s v1/test_oracle.py + - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py + # Integration test for streaming correctness (requires special branch). + - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine + +- label: V1 Others (CPU) + depends_on: ~ + source_file_dependencies: + - vllm/ + - tests/v1 + no_gpu: true + commands: + # split the test to avoid interference + - pytest -v -s -m 'cpu_test' v1/core + - pytest -v -s v1/structured_output + - pytest -v -s v1/test_serial_utils.py + - pytest -v -s -m 'cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'cpu_test' v1/metrics + +- label: Regression + timeout_in_minutes: 20 + source_file_dependencies: + - vllm/ + - tests/test_regression + commands: + - pip install modelscope + - pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional + +- label: Examples + timeout_in_minutes: 45 + working_dir: "/vllm-workspace/examples" + source_file_dependencies: + - vllm/entrypoints + - vllm/multimodal + - examples/ + commands: + - pip install tensorizer # for tensorizer test + - python3 offline_inference/basic/chat.py # for basic + - python3 offline_inference/basic/generate.py --model facebook/opt-125m + - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + # for multi-modal models + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for 
features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 + +- label: Metrics, Tracing (2 GPUs) + timeout_in_minutes: 20 + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/v1/tracing + commands: + - "pip install \ + 'opentelemetry-sdk>=1.26.0' \ + 'opentelemetry-api>=1.26.0' \ + 'opentelemetry-exporter-otlp>=1.26.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1'" + - pytest -v -s v1/tracing + +- label: Python-only Installation + depends_on: ~ + timeout_in_minutes: 20 + source_file_dependencies: + - tests/standalone_tests/python_only_compile.sh + - setup.py + commands: + - bash standalone_tests/python_only_compile.sh + +- label: Async Engine, Inputs, Utils, Worker + timeout_in_minutes: 50 + source_file_dependencies: + - vllm/ + - tests/multimodal + - tests/utils_ + commands: + - pytest -v -s -m 'not cpu_test' multimodal + - pytest -v -s utils_ + +- label: Async Engine, Inputs, Utils, Worker, Config (CPU) + depends_on: ~ + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/test_inputs.py + - tests/test_outputs.py + - tests/multimodal + - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ + - tests/tool_parsers + - tests/transformers_utils + - tests/config + no_gpu: true + commands: + - python3 standalone_tests/lazy_imports.py + - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py + - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers + - pytest -v -s transformers_utils + - pytest -v -s config + +- label: GPT-OSS Eval (B200) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + +- label: Batch Invariance (H100) + timeout_in_minutes: 25 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - vllm/model_executor/layers + - tests/v1/determinism/ + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pip install pytest-timeout pytest-forked + - pytest -v -s v1/determinism/test_batch_invariance.py + - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py \ No newline at end of file diff --git a/.buildkite/test_areas/model_executor.yaml b/.buildkite/test_areas/model_executor.yaml new file mode 100644 index 000000000000..996c8bb8b780 --- /dev/null +++ 
b/.buildkite/test_areas/model_executor.yaml @@ -0,0 +1,17 @@ +group: Model Executor +depends_on: + - image-build +steps: +- label: Model Executor + timeout_in_minutes: 35 + source_file_dependencies: + - vllm/engine/arg_utils.py + - vllm/config/model.py + - vllm/model_executor + - tests/model_executor + - tests/entrypoints/openai/test_tensorizer_entrypoint.py + commands: + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s model_executor + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py diff --git a/.buildkite/test_areas/models_basic.yaml b/.buildkite/test_areas/models_basic.yaml new file mode 100644 index 000000000000..39a5d51c4883 --- /dev/null +++ b/.buildkite/test_areas/models_basic.yaml @@ -0,0 +1,62 @@ +group: Models - Basic +depends_on: + - image-build +steps: +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_initialization.py + commands: + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset + +- label: Basic Models Tests (Extra Initialization) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py -k 'not test_can_initialize_small_subset' --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + commands: + - pytest -v -s models/test_transformers.py models/test_registry.py + +- label: Basic Models Test (Other CPU) # 5min + timeout_in_minutes: 10 + source_file_dependencies: + - vllm/ + - tests/models/test_utils.py + - tests/models/test_vision.py + no_gpu: true + commands: + - pytest -v -s models/test_utils.py models/test_vision.py + +- label: Transformers Nightly Models + working_dir: "/vllm-workspace/" + optional: true + soft_fail: true + commands: + - pip install --upgrade git+https://github.com/huggingface/transformers + - pytest -v -s tests/models/test_initialization.py + - pytest -v -s tests/models/test_transformers.py + - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/test_mapping.py + - python3 examples/offline_inference/basic/chat.py + - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl + # Whisper needs spawn method to avoid deadlock + - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper diff --git a/.buildkite/test_areas/models_distributed.yaml b/.buildkite/test_areas/models_distributed.yaml new file mode 100644 index 000000000000..b6bfbf2ddab4 --- /dev/null +++ b/.buildkite/test_areas/models_distributed.yaml @@ -0,0 +1,22 @@ +group: Models - Distributed +depends_on: + - image-build +steps: +- label: Distributed Model Tests (2 GPUs) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: 
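As an illustration of the sharding flags used in several of the steps here (`--shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT` together with `parallelism: N`), the sketch below shows one way a deterministic shard split can work: each parallel Buildkite job keeps only the tests whose index falls in its shard. The helper name and the round-robin assignment are assumptions for the example only; this is not the actual pytest-shard implementation and the snippet is not part of any file in this diff.

```python
# Hypothetical sketch of deterministic test sharding, in the spirit of the
# --shard-id/--num-shards flags above. Not the real pytest-shard code.
import os


def select_shard(items: list[str], shard_id: int, num_shards: int) -> list[str]:
    """Keep only the items assigned to this shard (round-robin by index)."""
    if not 0 <= shard_id < num_shards:
        raise ValueError(f"shard_id {shard_id} out of range for {num_shards} shards")
    return [item for idx, item in enumerate(items) if idx % num_shards == shard_id]


if __name__ == "__main__":
    # Buildkite exposes these to each parallel job; default to a single shard.
    shard_id = int(os.environ.get("BUILDKITE_PARALLEL_JOB", 0))
    num_shards = int(os.environ.get("BUILDKITE_PARALLEL_JOB_COUNT", 1))

    all_tests = [f"tests/models/test_initialization.py::test_{i}" for i in range(8)]
    print(select_shard(all_tests, shard_id, num_shards))
```

Because the split depends only on the test index and the shard count, every parallel job sees a disjoint, stable subset without any coordination between jobs.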
+ - vllm/model_executor/model_loader/sharded_state_loader.py + - vllm/model_executor/models/ + - tests/basic_correctness/ + - tests/model_executor/model_loader/test_sharded_state_loader.py + - tests/models/ + commands: + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/language -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' diff --git a/.buildkite/test_areas/models_language.yaml b/.buildkite/test_areas/models_language.yaml new file mode 100644 index 000000000000..f70192c4ebc0 --- /dev/null +++ b/.buildkite/test_areas/models_language.yaml @@ -0,0 +1,91 @@ +group: Models - Language +depends_on: + - image-build +steps: +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language + commands: + # Test standard language models, excluding a subset of slow tests + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and (not slow_test)' + +- label: Language Models Tests (Extra Standard) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. 
Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' + # Shard hybrid language model tests + - pytest -v -s models/language/generation -m hybrid_model --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' + - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' + +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + +- label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling + commands: + - pytest -v -s models/language/pooling -m 'not core_model' + +- label: Language Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test diff --git a/.buildkite/test_areas/models_multimodal.yaml b/.buildkite/test_areas/models_multimodal.yaml new file mode 100644 index 000000000000..fc24068c20a4 --- /dev/null +++ b/.buildkite/test_areas/models_multimodal.yaml @@ -0,0 +1,79 @@ +group: Models - Multimodal +depends_on: + - image-build +steps: +- label: Multi-Modal Models (Standard) # 60min + timeout_in_minutes: 80 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing + - cd .. 
&& VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Processor # 44min + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Accuracy Eval (Small Models) # 50min + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/multimodal/ + - vllm/inputs/ + - vllm/v1/core/ + commands: + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 + +- label: Multi-Modal Models (Extended) 1 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing + +- label: Multi-Modal Models (Extended) 2 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' + +- label: Multi-Modal Models (Extended) 3 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' + +# This test is used only in PR development phase to test individual models and should never run on main +- label: Custom Models + optional: true + commands: + - echo 'Testing custom models...' + # PR authors can temporarily add commands below to test individual models + # e.g. 
pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* diff --git a/.buildkite/test_areas/plugins.yaml b/.buildkite/test_areas/plugins.yaml new file mode 100644 index 000000000000..60c179aa098e --- /dev/null +++ b/.buildkite/test_areas/plugins.yaml @@ -0,0 +1,34 @@ +group: Plugins +depends_on: + - image-build +steps: +- label: Plugin Tests (2 GPUs) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test + # begin stat_logger plugins test + - pip install -e ./plugins/vllm_add_dummy_stat_logger + - pytest -v -s plugins_tests/test_stats_logger_plugins.py + - pip uninstall dummy_stat_logger -y + # end stat_logger plugins test + # other tests continue here: + - pytest -v -s plugins_tests/test_scheduler_plugins.py + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins diff --git a/.buildkite/test_areas/pytorch.yaml b/.buildkite/test_areas/pytorch.yaml new file mode 100644 index 000000000000..703c82eb1a91 --- /dev/null +++ b/.buildkite/test_areas/pytorch.yaml @@ -0,0 +1,50 @@ +group: PyTorch +depends_on: + - image-build +steps: +- label: PyTorch Compilation Unit Tests + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # Run unit tests defined directly under compile/, + # not including subdirectories, which are usually heavier + # tests covered elsewhere. + # Use `find` to launch multiple instances of pytest so that + # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 + - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;" + +- label: PyTorch Fullgraph Smoke Test + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # Run smoke tests under fullgraph directory, except test_full_graph.py + # as it is a heavy test that is covered in other steps. 
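The compile steps above and below rely on `find ... -exec pytest -s -v {} \;` so that each `test_*.py` file runs in its own pytest process, sidestepping the cross-test interference tracked in the linked issue. The following is only a rough Python restatement of that driver, shown to make the pattern explicit; the function name is invented and the script is not part of the CI configuration.

```python
# Illustration only: what the `find ... -exec pytest` idiom in these steps does,
# expressed as a small Python driver.
import subprocess
import sys
from pathlib import Path


def run_each_file_separately(test_dir: str, recursive: bool = False) -> int:
    """Run every test_*.py under test_dir in its own pytest process."""
    root = Path(test_dir)
    files = sorted(root.rglob("test_*.py") if recursive else root.glob("test_*.py"))
    worst = 0
    for path in files:
        # A fresh interpreter per file keeps one file's global state (or crash)
        # from affecting the others, which is what find/-exec achieves in CI.
        proc = subprocess.run([sys.executable, "-m", "pytest", "-s", "-v", str(path)])
        worst = worst or proc.returncode
    return worst


if __name__ == "__main__":
    sys.exit(run_each_file_separately("compile/"))
```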
+ # Use `find` to launch multiple instances of pytest so that + # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 + - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\;" + +- label: PyTorch Fullgraph + timeout_in_minutes: 40 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # fp8 kv scales not supported on sm89, tested on Blackwell instead + - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time + # Wrap with quotes to escape yaml and avoid starting -k string with a - + - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'" + +- label: Pytorch Nightly Dependency Override Check # 2min + # if this test fails, it means the nightly torch version is not compatible with some + # of the dependencies. Please check the error message and add the package to whitelist + # in /vllm/tools/pre_commit/generate_nightly_torch_test.py + soft_fail: true + source_file_dependencies: + - requirements/nightly_torch_test.txt + commands: + - bash standalone_tests/pytorch_nightly_dependency.sh \ No newline at end of file diff --git a/.buildkite/test_areas/quantization.yaml b/.buildkite/test_areas/quantization.yaml new file mode 100644 index 000000000000..6e89d6af3b8d --- /dev/null +++ b/.buildkite/test_areas/quantization.yaml @@ -0,0 +1,46 @@ +group: Quantization +depends_on: + - image-build +steps: +- label: Quantization + timeout_in_minutes: 90 + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + commands: + # temporary install here since we need nightly, will move to requirements/test.in + # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved + # TODO(jerryzh168): resolve the above comment + - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 + - uv pip install --system conch-triton-kernels + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py + +- label: Quantized MoE Test (B200) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - tests/quantization/test_blackwell_moe.py + - vllm/model_executor/models/deepseek_v2.py + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/models/llama4.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization/compressed_tensors + - vllm/model_executor/layers/quantization/modelopt.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - pytest -s -v tests/quantization/test_blackwell_moe.py + +- label: Quantized Models Test + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/model_executor/layers/quantization + - tests/models/quantization + commands: + - pytest -v -s models/quantization diff --git a/.buildkite/test_areas/samplers.yaml b/.buildkite/test_areas/samplers.yaml new file mode 100644 index 000000000000..ad377148fd07 --- /dev/null +++ b/.buildkite/test_areas/samplers.yaml @@ -0,0 +1,14 @@ +group: Samplers +depends_on: + - image-build +steps: +- label: Samplers Test + timeout_in_minutes: 75 + 
source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers + - tests/conftest.py + commands: + - pytest -v -s samplers + - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers diff --git a/.buildkite/test_areas/tool_use.yaml b/.buildkite/test_areas/tool_use.yaml new file mode 100644 index 000000000000..69527a121422 --- /dev/null +++ b/.buildkite/test_areas/tool_use.yaml @@ -0,0 +1,13 @@ +group: Tool use +depends_on: + - image-build +steps: +- label: OpenAI-Compatible Tool Use + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental] + fast_check: false + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use diff --git a/.buildkite/test_areas/weight_loading.yaml b/.buildkite/test_areas/weight_loading.yaml new file mode 100644 index 000000000000..cfc5bb20fe7a --- /dev/null +++ b/.buildkite/test_areas/weight_loading.yaml @@ -0,0 +1,25 @@ +group: Weight Loading +depends_on: + - image-build +steps: +- label: Weight Loading Multiple GPU # 33min + timeout_in_minutes: 45 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt diff --git a/.github/mergify.yml b/.github/mergify.yml index 997a40e18e58..3ad79f93bc7a 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -14,6 +14,52 @@ pull_request_rules: comment: message: "Documentation preview: https://vllm--{{number}}.org.readthedocs.build/en/{{number}}/" +- name: comment-pre-commit-failure + description: Comment on PR when pre-commit check fails + conditions: + - status-failure=pre-commit + - -closed + - -draft + actions: + comment: + message: | + Hi @{{author}}, the pre-commit checks have failed. Please run: + + ```bash + uv pip install pre-commit + pre-commit install + pre-commit run --all-files + ``` + + Then, commit the changes and push to your branch. + + For future commits, `pre-commit` will run automatically on changed files before each commit. + + > [!TIP] + >
+ > Is mypy or markdownlint failing? + >
+ > mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally: + > + > ```bash + > # For mypy (substitute "3.10" with the failing version if needed) + > pre-commit run --hook-stage manual mypy-3.10 + > # For markdownlint + > pre-commit run --hook-stage manual markdownlint + > ``` + >
+ +- name: comment-dco-failure + description: Comment on PR when DCO check fails + conditions: + - status-failure=dco + - -closed + - -draft + actions: + comment: + message: | + Hi @{{author}}, the DCO check has failed. Please click on DCO in the Checks section for instructions on how to resolve this. + - name: label-ci-build description: Automatically apply ci/build label conditions: @@ -140,7 +186,7 @@ pull_request_rules: - files~=^tests/entrypoints/test_context.py - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py - - files~=^vllm/entrypoints/harmony_utils.py + - files~=^vllm/entrypoints/openai/parser/harmony_utils.py - files~=^vllm/entrypoints/tool_server.py - files~=^vllm/entrypoints/tool.py - files~=^vllm/entrypoints/context.py @@ -358,4 +404,4 @@ pull_request_rules: actions: label: add: - - kv-connector \ No newline at end of file + - kv-connector diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index c3e132a536a4..df8910837715 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -13,10 +13,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: '3.12' diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index 7d565ef9f2e4..629966b95933 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -105,6 +105,31 @@ jobs: } ], }, + cpu: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "CPU Backend", + searchIn: "title" + }, + { + term: "x86", + searchIn: "title" + }, + { + term: "ARM", + searchIn: "title" + }, + { + term: "Apple Silicon", + searchIn: "title" + }, + { + term: "IBM Z", + searchIn: "title" + }, + ], + }, // Add more label configurations here as needed // example: { // keywords: [...], diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml index a183033c9add..e80a5c0cc80f 100644 --- a/.github/workflows/macos-smoke-test.yml +++ b/.github/workflows/macos-smoke-test.yml @@ -12,7 +12,7 @@ jobs: timeout-minutes: 30 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6.0.1 - uses: astral-sh/setup-uv@v7 with: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index e21d13b8161f..1041653c2f57 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,8 +16,8 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index dca3089f496c..44bf71db5e9d 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -7,13 +7,15 @@ on: jobs: close-issues-and-pull-requests: + # 
Prevents triggering on forks or other repos + if: github.repository == 'vllm-project/vllm' permissions: issues: write pull-requests: write actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 + - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/CMakeLists.txt b/CMakeLists.txt index 86746a0db4c0..cd52df86e034 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + + # marlin arches for fp16 output + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) + cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + if (MARLIN_ARCHS) # @@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log @@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") else() - set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} - CACHE STRING "Last run Marlin generate script hash" FORCE) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin generate script hash and arch" FORCE) message(STATUS "Marlin generation completed successfully.") endif() else() message(STATUS "Marlin generation script has not changed, skipping generation.") endif() - file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") + file(GLOB 
MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") @@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + + if (MARLIN_FP8_ARCHS) + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) + endif() + set(MARLIN_SRCS "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( @@ -604,12 +637,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") @@ -786,7 +822,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH} ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} RESULT_VARIABLE machete_generation_result OUTPUT_VARIABLE machete_generation_output @@ -838,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) set(SRCS - "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" + ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -908,7 +947,6 @@ 
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/moe_lora_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -938,8 +976,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches + # note that we always set `use_atomic_add=False` for moe marlin now, + # so we don't need 9.0 for bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # @@ -949,16 +994,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MOE_MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log @@ -971,7 +1018,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") else() - set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} CACHE STRING "Last run Marlin MOE generate script hash" FORCE) message(STATUS "Marlin MOE generation completed successfully.") endif() @@ -979,16 +1026,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Marlin MOE generation script has not changed, skipping generation.") endif() - file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") set_gencode_flags_for_srcs( - SRCS "${MOE_WNAA16_MARLIN_SRC}" + SRCS "${MARLIN_MOE_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - 
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + set_source_files_properties(${MARLIN_MOE_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() - - list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) + + if (MARLIN_MOE_FP8_ARCHS) + file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_FP8_SRC}" + CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_FP8_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) + endif() message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() diff --git a/README.md b/README.md index 033e1035d891..26222b815370 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing). - [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). @@ -136,16 +137,19 @@ Compute Resources: - Alibaba Cloud - AMD - Anyscale +- Arm - AWS - Crusoe Cloud - Databricks - DeepInfra - Google Cloud +- IBM - Intel - Lambda Lab - Nebius - Novita AI - NVIDIA +- Red Hat - Replicate - Roblox - RunPod diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index d1bdb4c43f10..9a9600e08daf 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number ``` -#### 2. Maximize Throughput with a Latency Requirement +### 2. Maximize Throughput with a Latency Requirement - **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms. - **Configuration**: @@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=500 ``` -#### 3. Maximize Throughput with Prefix Caching and Latency Requirements +### 3. Maximize Throughput with Prefix Caching and Latency Requirements - **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms. 
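All three tuning goals in this README reduce to the selection rule that auto_tune.sh applies further down: among the parameter combinations tried, keep the one with the highest measured request throughput whose P99 end-to-end latency stays within MAX_LATENCY_ALLOWED_MS. Below is a minimal Python sketch of that rule only; the TrialResult structure and the sample numbers are invented for the example and the code is not part of auto_tune.sh.

```python
# Illustration of the accept/reject rule described above, not code from auto_tune.sh.
from dataclasses import dataclass


@dataclass
class TrialResult:
    max_num_seqs: int
    max_num_batched_tokens: int
    throughput_req_s: float  # "Request throughput (req/s)" parsed from the benchmark log
    p99_e2el_ms: float       # "P99 E2EL (ms)" parsed from the benchmark log


def pick_best(results: list[TrialResult], max_latency_allowed_ms: float) -> TrialResult | None:
    """Return the highest-throughput trial whose P99 E2E latency meets the budget."""
    feasible = [r for r in results if r.p99_e2el_ms <= max_latency_allowed_ms]
    return max(feasible, key=lambda r: r.throughput_req_s, default=None)


if __name__ == "__main__":
    trials = [
        TrialResult(128, 1024, throughput_req_s=9.8, p99_e2el_ms=480.0),
        TrialResult(128, 2048, throughput_req_s=11.2, p99_e2el_ms=620.0),
        TrialResult(256, 4096, throughput_req_s=12.5, p99_e2el_ms=700.0),
    ]
    # Only the first trial satisfies a 500 ms budget, so it is selected even
    # though the other two have higher raw throughput.
    print(pick_best(trials, max_latency_allowed_ms=500))
```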
- **Configuration**: diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 56b721cbb402..a245e2022e60 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -18,6 +18,11 @@ MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} +HOSTNAME=$(hostname) +if [[ -z "$HOSTNAME" ]]; then + echo "Error: Failed to determine hostname." >&2 + exit 1 +fi LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" @@ -82,6 +87,7 @@ start_server() { "$MODEL" "--disable-log-requests" "--port" "8004" + "--host" "$HOSTNAME" "--gpu-memory-utilization" "$gpu_memory_utilization" "--max-num-seqs" "$max_num_seqs" "--max-num-batched-tokens" "$max_num_batched_tokens" @@ -96,8 +102,9 @@ start_server() { # This correctly passes each element as a separate argument. if [[ -n "$profile_dir" ]]; then # Start server with profiling enabled - VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ - vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}" + VLLM_SERVER_DEV_MODE=1 \ + vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 & else # Start server without profiling VLLM_SERVER_DEV_MODE=1 \ @@ -112,7 +119,7 @@ start_server() { # since that we should always have permission to send signal to the server process. kill -0 $server_pid 2> /dev/null || break - RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) + RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) if [[ "$STATUS_CODE" -eq 200 ]]; then server_started=1 @@ -172,6 +179,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -187,7 +195,7 @@ run_benchmark() { request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do # clear prefix cache - curl -X POST http://0.0.0.0:8004/reset_prefix_cache + curl -X POST http://${HOSTNAME}:8004/reset_prefix_cache sleep 5 bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" vllm bench serve \ @@ -203,6 +211,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -303,6 +312,7 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 \ --profile &> "$bm_log" else diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 4021fede7215..831b76b66e09 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -620,7 +620,7 @@ def get_tokenizer( kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: - from vllm.transformers_utils.tokenizer import MistralTokenizer + from 
vllm.tokenizers.mistral import MistralTokenizer except ImportError as e: raise ImportError( "MistralTokenizer requires vllm package.\n" diff --git a/benchmarks/benchmark_hash.py b/benchmarks/benchmark_hash.py new file mode 100644 index 000000000000..08cdc012d652 --- /dev/null +++ b/benchmarks/benchmark_hash.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Micro benchmark comparing built-in hash(), SHA-256, and xxHash. + +This focuses on a single test payload shaped like the prefix-cache hash input: + (32-byte bytes object, 32-int tuple) + +Usage: + python benchmarks/hash_micro_benchmark.py --iterations 20000 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import time +from collections.abc import Callable, Iterable + +from vllm.utils.hashing import sha256, xxhash + + +def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]: + """Generate a deterministic test payload.""" + random.seed(seed) + bytes_data = bytes(random.getrandbits(8) for _ in range(32)) + int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32)) + return (bytes_data, int_tuple) + + +def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int): + """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.""" + times: list[float] = [] + + # Warm-up to avoid first-run noise. + for _ in range(200): + func(data) + + for _ in range(iterations): + start = time.perf_counter() + func(data) + end = time.perf_counter() + times.append(end - start) + + avg = statistics.mean(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + return avg, std + + +def _run_benchmarks( + benchmarks: Iterable[tuple[str, Callable[[tuple], object]]], + data: tuple, + iterations: int, +): + """Yield (name, avg, std) for each benchmark, skipping unavailable ones.""" + for name, func in benchmarks: + try: + avg, std = _benchmark_func(func, data, iterations) + except ModuleNotFoundError as exc: + print(f"Skipping {name}: {exc}") + continue + yield name, avg, std + + +def builtin_hash(data: tuple) -> int: + """Wrapper for Python's built-in hash().""" + return hash(data) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--iterations", + type=int, + default=10_000, + help="Number of measured iterations per hash function.", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for test payload." 
+ ) + args = parser.parse_args() + + data = _generate_test_data(args.seed) + benchmarks = ( + ("SHA256 (pickle)", sha256), + ("xxHash (pickle)", xxhash), + ("built-in hash()", builtin_hash), + ) + + print("=" * 60) + print("HASH FUNCTION MICRO BENCHMARK") + print("=" * 60) + print("Test data: (32-byte bytes object, 32-int tuple)") + print(f"Iterations: {args.iterations:,}") + print("=" * 60) + + results = list(_run_benchmarks(benchmarks, data, args.iterations)) + builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None) + + print("\nResults:") + for name, avg, std in results: + print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs") + + if builtin_entry: + _, builtin_avg, _ = builtin_entry + print("\n" + "=" * 60) + print("SUMMARY (relative to built-in hash())") + print("=" * 60) + for name, avg, _ in results: + if name == "built-in hash()": + continue + speed_ratio = avg / builtin_avg + print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()") + else: + print("\nBuilt-in hash() result missing; cannot compute speed ratios.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index dedb564fffac..b5373d383b54 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -32,12 +32,11 @@ def benchmark_propose(args): model_config = ModelConfig( model="facebook/opt-125m", - task="generate", max_model_len=args.num_token + args.num_spec_token, tokenizer="facebook/opt-125m", tokenizer_mode="auto", dtype="auto", - seed=None, + seed=0, trust_remote_code=False, ) proposer = NgramProposer( @@ -108,7 +107,10 @@ def benchmark_batched_propose(args): device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group diff --git a/benchmarks/benchmark_prefix_block_hash.py b/benchmarks/benchmark_prefix_block_hash.py new file mode 100644 index 000000000000..8bcd8af0d310 --- /dev/null +++ b/benchmarks/benchmark_prefix_block_hash.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Simple benchmark to compare prefix-cache block hashing algorithms. 
+ +Example: + python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import sys +import time +from collections.abc import Callable, Iterable, Sequence + +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash + +SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor") + + +def _generate_blocks( + num_blocks: int, block_size: int, vocab_size: int, seed: int +) -> list[list[int]]: + rng = random.Random(seed) + return [ + [rng.randrange(vocab_size) for _ in range(block_size)] + for _ in range(num_blocks) + ] + + +def _hash_all_blocks( + hash_fn: Callable[[object], bytes], + blocks: Iterable[Sequence[int]], +) -> float: + parent_hash: BlockHash | None = None + start = time.perf_counter() + for block in blocks: + parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None) + end = time.perf_counter() + return end - start + + +def _benchmark( + hash_algo: str, + blocks: list[list[int]], + trials: int, +) -> tuple[float, float, float] | None: + try: + hash_fn = get_hash_fn_by_name(hash_algo) + init_none_hash(hash_fn) + timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)] + except ModuleNotFoundError as exc: + print(f"Skipping {hash_algo}: {exc}", file=sys.stderr) + return None + + avg = statistics.mean(timings) + best = min(timings) + # throughput: tokens / second + tokens_hashed = len(blocks) * len(blocks[0]) + throughput = tokens_hashed / best + return avg, best, throughput + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.") + parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.") + parser.add_argument( + "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)." + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument( + "--trials", type=int, default=5, help="Number of timed trials per algorithm." + ) + parser.add_argument( + "--algorithms", + nargs="+", + default=SUPPORTED_ALGOS, + choices=SUPPORTED_ALGOS, + help="Hash algorithms to benchmark.", + ) + args = parser.parse_args() + + blocks = _generate_blocks( + args.num_blocks, args.block_size, args.vocab_size, args.seed + ) + print( + f"Benchmarking {len(args.algorithms)} algorithms on " + f"{args.num_blocks} blocks (block size={args.block_size})." 
+ ) + + for algo in args.algorithms: + result = _benchmark(algo, blocks, args.trials) + if result is None: + continue + + avg, best, throughput = result + print( + f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s " + f"throughput: {throughput / 1e6:.2f}M tokens/s" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 28fc383a318d..e6391134ff93 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -40,7 +40,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 55001cf3722a..33aca831883a 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -46,7 +46,7 @@ from transformers import PreTrainedTokenizerBase try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer @@ -574,7 +574,7 @@ async def limited_request_func(request_func_input, pbar): ) print( "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput + "Total token throughput (tok/s):", metrics.total_token_throughput ) ) @@ -963,8 +963,7 @@ def create_argument_parser(): parser.add_argument( "--profile", action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", + help="Use vLLM Profiling. 
--profiler-config must be provided on the server.", ) parser.add_argument( "--result-dir", diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index d809bf1db8cb..fb3329975cee 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -14,6 +14,9 @@ import vllm._custom_ops as ops from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) @dataclass @@ -22,6 +25,7 @@ class bench_params_t: hidden_size: int add_residual: bool dtype: torch.dtype + group_size: list[int] def description(self): return ( @@ -29,6 +33,7 @@ def description(self): f"x D {self.hidden_size} " f"x R {self.add_residual} " f"x DT {self.dtype}" + f"x GS {self.group_size}" ) @@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]: HIDDEN_SIZES = list(range(1024, 8129, 1024)) ADD_RESIDUAL = [True, False] DTYPES = [torch.bfloat16, torch.float] + GROUP_SIZES = [[1, 64], [1, 128]] - combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES) bench_params = list( - map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations) ) return bench_params @@ -52,6 +58,7 @@ def unfused_int8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -69,6 +76,7 @@ def unfused_fp8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -81,23 +89,63 @@ def unfused_fp8_impl( torch_out, _ = ops.scaled_fp8_quant(torch_out) +def unfused_groupwise_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + + def fused_impl( rms_norm_layer: RMSNorm, # this stores the weights x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): out, _ = ops.rms_norm_dynamic_per_token_quant( x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual ) +def fused_groupwise_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_per_block_quant( + x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + group_size, + residual=residual, + is_scale_transposed=True, + ) + + # Bench functions def bench_fn( rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, quant_dtype: torch.dtype, + group_size: list[int], label: str, sub_label: str, fn: Callable, @@ -110,10 +158,11 @@ def bench_fn( "x": x, "residual": residual, "quant_dtype": quant_dtype, + "group_size": group_size, "fn": fn, } return TBenchmark.Timer( - stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", globals=globals, label=label, sub_label=sub_label, @@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: 
str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, unfused_int8_impl, @@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, unfused_fp8_impl, @@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, fused_impl, @@ -189,6 +241,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, fused_impl, @@ -196,6 +249,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu ) ) + # unfused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # fused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_groupwise_impl, + "fused_groupwise_fp8_impl", + ) + ) + print_timers(timers) return timers diff --git a/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py new file mode 100644 index 000000000000..04921dafbdbe --- /dev/null +++ b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum +from itertools import product +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +from .utils import ArgPool, Bench, CudaGraphBenchParams + +GROUP_SIZE = 128 +FLOAT8_T = torch.float8_e4m3fn + + +def print_timers(timers: list[TMeasurement], cuda_graph_nops: int): + print( + f"Note : The timings reported above is for {cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {cuda_graph_nops} for single invocation " + "timings." + ) + compare = TBenchmark.Compare(timers) + compare.print() + + +class ImplType(Enum): + SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1 + REFERENCE = 2 + + def get_impl(self): + if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return silu_mul_per_token_group_quant_fp8_colmajor + elif self == ImplType.REFERENCE: + return reference + raise ValueError(f"Unrecognized ImplType {self}") + + +@dataclass +class BenchmarkTensors: + input: torch.Tensor + output: torch.Tensor + + # Reference act output tensor + ref_act_out: torch.Tensor + ref_quant_out: torch.Tensor + + @staticmethod + def make(T: int, N: int) -> "BenchmarkTensors": + assert T % GROUP_SIZE == 0 + assert N % (GROUP_SIZE * 2) == 0 + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + # silu_mul_per_token_group_quant_fp8_colmajor output. + output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to( + FLOAT8_T + ) + + # reference output. 
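+        # The input packs the gate and up projections along the last dimension
+        # (width N, bf16); both the fused kernel and the reference path therefore
+        # emit activations of width N // 2, quantized to FP8 in groups of
+        # GROUP_SIZE along the last dimension.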
+ ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + ref_quant_out = torch.empty( + (T, N // 2), dtype=torch.bfloat16, device="cuda" + ).to(FLOAT8_T) + + return BenchmarkTensors( + input=input, + output=output, + ref_act_out=ref_act_out, + ref_quant_out=ref_quant_out, + ) + + @property + def T(self): + return self.input.size(0) + + @property + def N(self): + return self.input.size(1) + + def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]: + if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return { + "input": self.input, + "output": self.output, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + elif impl_type == ImplType.REFERENCE: + return { + "input": self.input, + "act_out": self.ref_act_out, + "quant_out": self.ref_quant_out, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + raise ValueError(f"Unrecognized impl_type {impl_type}") + + +def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + assert quant_out.size() == x.size() + # Allocate the scale tensor column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_q = quant_out + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_T) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference( + input: torch.Tensor, + act_out: torch.Tensor, + quant_out: torch.Tensor, + use_ue8m0: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops._C.silu_and_mul(act_out, input) + return reference_quant(act_out, quant_out, use_ue8m0) + + +def bench_impl( + bench_tensors: list[BenchmarkTensors], impl_type: ImplType +) -> TMeasurement: + T = bench_tensors[0].T + N = bench_tensors[0].N + + arg_pool_size = len(bench_tensors) + kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors] + + # warmup + for kwargs in kwargs_list: + impl_type.get_impl()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + cuda_graph_params = None + cuda_graph_params = CudaGraphBenchParams(arg_pool_size) + timer = None + with Bench( + cuda_graph_params, + "silu-mul-quant", + f"num_tokens={T}, N={N}", + impl_type.name, + impl_type.get_impl(), + **kwargs, + ) as bench: + timer = bench.run() + return timer + + +def test_correctness(T: int, N: int): + print(f"Testing num_tokens={T}, N={N} ...") + + bench_tensor = BenchmarkTensors.make(T, N) + + def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]: + return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl)) + + # reference output + ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE) + + # test ouptut + out_q, out_s = output_from_impl( + ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + + torch.testing.assert_close(ref_out_q.to(torch.float32), 
out_q.to(torch.float32)) + torch.testing.assert_close(ref_out_s, out_s) + + +def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]: + timers = [] + for N, T in product(Ns, Ts): + test_correctness(T, N) + + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(T, N) for _ in range(arg_pool_size) + ] + + silu_mul_quant_timer = bench_impl( + bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + timers.append(silu_mul_quant_timer) + reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE) + timers.append(reference_timer) + + print_timers( + [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size + ) + + print_timers(timers, cuda_graph_nops=arg_pool_size) + + return timers + + +if __name__ == "__main__": + T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)] + N = [2048, 4096, 8192] + + print(f"T = {T}, N = {N}") + run(T, N, arg_pool_size=8) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 8787724d77cf..ac78c019a59e 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: b_q_weight=w_q, b_bias=None, b_scales=w_s, + a_scales=None, global_scale=None, b_zeros=w_zp, g_idx=g_idx, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 12ca9214b1f9..48d790aec9e0 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -263,7 +263,7 @@ def gen_allspark_params(): results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -273,7 +273,7 @@ def gen_allspark_params(): results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_mla_k_concat.py b/benchmarks/kernels/benchmark_mla_k_concat.py new file mode 100644 index 000000000000..fb3b6c8f1200 --- /dev/null +++ b/benchmarks/kernels/benchmark_mla_k_concat.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation +in MLA (Multi-head Latent Attention) prefill. + +This validates that the optimization from commit 8d4142bd is beneficial across +various batch sizes, not just the originally tested batch size of 32768. 
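+
+With the DeepSeek-V3 shapes used here, each batch concatenates k_nope of shape
+[B, 128, 128] with k_pe of shape [B, 1, 64] (broadcast across heads) into a
+single [B, 128, 192] tensor.
+
+The script takes no arguments; run it directly, e.g.:
+    python benchmarks/kernels/benchmark_mla_k_concat.py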
+""" + +import time +from collections.abc import Callable + +import torch + +# DeepSeek-V3 MLA dimensions +NUM_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +PE_DIM = 64 + + +def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Original torch.cat approach with expand.""" + return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + +def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Optimized direct copy approach (avoids expand + cat overhead).""" + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + +def benchmark_method( + method: Callable, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + num_warmup: int = 10, + num_iters: int = 100, +) -> float: + """Benchmark a concatenation method and return mean latency in ms.""" + # Warmup + for _ in range(num_warmup): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + end = time.perf_counter() + + return (end - start) / num_iters * 1000 # Convert to ms + + +@torch.inference_mode() +def run_benchmark(dtype: torch.dtype, dtype_name: str): + """Run benchmark for a specific dtype.""" + torch.set_default_device("cuda") + + # Batch sizes to test (powers of 2 from 32 to 65536) + batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536] + + print("=" * 80) + print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation") + print("=" * 80) + print( + f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], " + f"k_pe=[B, 1, {PE_DIM}]" + ) + print(f"dtype: {dtype_name}") + print() + print( + f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | " + f"{'Speedup':>8} | {'Reduction':>10}" + ) + print("-" * 70) + + results = [] + for batch_size in batch_sizes: + # Create input tensors (generate in float32 then convert for FP8 compatibility) + k_nope = torch.randn( + batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + k_pe = torch.randn( + batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + + # Benchmark both methods + cat_time = benchmark_method(cat_method, k_nope, k_pe) + direct_time = benchmark_method(direct_copy_method, k_nope, k_pe) + + speedup = cat_time / direct_time + reduction = (1 - direct_time / cat_time) * 100 + + results.append((batch_size, cat_time, direct_time, speedup, reduction)) + + print( + f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | " + f"{speedup:>7.2f}x | {reduction:>9.1f}%" + ) + + print("=" * 80) + + # Summary statistics + speedups = [r[3] for r in results] + print("\nSpeedup summary:") + print(f" Min: {min(speedups):.2f}x") + print(f" Max: {max(speedups):.2f}x") + print(f" Mean: {sum(speedups) / len(speedups):.2f}x") + + # Find crossover point + crossover_batch = None + for batch_size, _, _, speedup, _ in results: + if speedup >= 1.0: + crossover_batch = batch_size + break + + print("\nConclusion:") + if crossover_batch: + print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}") + # Filter for large batches (>= 512 which is typical for prefill) + large_batch_speedups = [r[3] for r in results if r[0] >= 512] + if large_batch_speedups: + avg_large = sum(large_batch_speedups) / len(large_batch_speedups) + print(f" - For batch 
sizes >= 512: avg speedup = {avg_large:.2f}x") + print(" - MLA prefill typically uses large batches, so optimization is effective") + + return results + + +@torch.inference_mode() +def main(): + # Test bfloat16 + print("\n") + run_benchmark(torch.bfloat16, "bfloat16") + + # Test float8_e4m3fn + print("\n") + run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py index f540cff6261a..5f9a131f79b0 100644 --- a/benchmarks/kernels/benchmark_moe_align_block_size.py +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: num_tokens_range = [1, 16, 256, 4096] num_experts_range = [16, 64, 224, 256, 280, 512] topk_range = [1, 2, 8] -configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) +ep_size_range = [1, 8] +configs = list( + itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range) +) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["num_tokens", "num_experts", "topk"], + x_names=["num_tokens", "num_experts", "topk", "ep_size"], x_vals=configs, line_arg="provider", line_vals=["vllm"], @@ -38,16 +41,26 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: args={}, ) ) -def benchmark(num_tokens, num_experts, topk, provider): +def benchmark(num_tokens, num_experts, topk, ep_size, provider): """Benchmark function for Triton.""" block_size = 256 + torch.cuda.manual_seed_all(0) topk_ids = get_topk_ids(num_tokens, num_experts, topk) + e_map = None + if ep_size != 1: + local_e = num_experts // ep_size + e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + quantiles = [0.5, 0.2, 0.8] if provider == "vllm": ms, min_ms, max_ms = triton.testing.do_bench( - lambda: moe_align_block_size(topk_ids, block_size, num_experts), + lambda: moe_align_block_size( + topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True + ), quantiles=quantiles, ) diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py index 83bd91917508..09de5fa822f8 100644 --- a/benchmarks/kernels/benchmark_mrope.py +++ b/benchmarks/kernels/benchmark_mrope.py @@ -99,7 +99,6 @@ def benchmark_mrope( # the parameters to compute the q k v size based on tp_size mrope_helper_class = get_rope( head_size=head_dim, - rotary_dim=head_dim, max_position=max_position, is_neox_style=is_neox_style, rope_parameters=rope_parameters, diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 074b7a440b61..7a1bc050bb33 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device): def benchmark(batch_size, seq_len, num_heads, provider): dtype = torch.bfloat16 max_position = 8192 - base = 10000 - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope_parameters = {"partial_rotary_factor": rotary_dim / head_size} + rope = get_rope(head_size, max_position, is_neox_style, rope_parameters) rope = rope.to(dtype=dtype, device=device) cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) diff --git 
a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index fbbb03c5ed46..85b286f8d8d0 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON endif() # Build ACL with CMake - set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF") - set(CMAKE_BUILD_TYPE "Release") - set(ARM_COMPUTE_ARCH "armv8.2-a") - set(ARM_COMPUTE_ENABLE_ASSERTS "OFF") - set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ARM_COMPUTE_ENABLE_OPENMP "ON") - set(ARM_COMPUTE_ENABLE_WERROR "OFF") - set(ARM_COMPUTE_BUILD_EXAMPLES "OFF") - set(ARM_COMPUTE_BUILD_TESTING "OFF") - set(_cmake_config_cmd ${CMAKE_COMMAND} -G Ninja -B build -DARM_COMPUTE_BUILD_SHARED_LIB=OFF diff --git a/cmake/utils.cmake b/cmake/utils.cmake index ca0062ba4fab..bdb2ba74d944 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR) run_python(_VLLM_TORCH_GOMP_PATH " import os, glob -try: - import torch - torch_pkg = os.path.dirname(torch.__file__) - site_root = os.path.dirname(torch_pkg) - torch_libs = os.path.join(site_root, 'torch.libs') - print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0]) -except: - print('') +import torch +torch_pkg = os.path.dirname(torch.__file__) +site_root = os.path.dirname(torch_pkg) + +# Search both torch.libs and torch/lib +roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')] +candidates = [] +for root in roots: + if not os.path.isdir(root): + continue + candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*'))) + +print(candidates[0] if candidates else '') " - "failed to probe torch.libs for libgomp") + "failed to probe for libgomp") if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}") return() @@ -495,7 +500,13 @@ function (define_extension_target MOD_NAME) set(SOABI_KEYWORD "") endif() - if (ARG_USE_SABI) + run_python(IS_FREETHREADED_PYTHON + "import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)" + "Failed to determine whether interpreter is free-threaded") + + # Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809), + # so avoid using the stable ABI under free-threading only. + if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON) Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}") else() Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}") diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 229d9862fb67..27d1e990c611 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel( scalar_t* output, float* output_lse, const scalar_t* prefix_output, const float* prefix_lse, const scalar_t* suffix_output, const float* suffix_lse, const uint num_tokens, const uint num_heads, - const uint head_size) { + const uint head_size, const uint prefix_head_stride, + const uint output_head_stride) { using pack_128b_t = uint4; const uint pack_size = 16 / sizeof(scalar_t); const uint threads_per_head = head_size / pack_size; @@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel( const uint head_idx = token_head_idx % num_heads; const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc. 
- const uint head_offset = - token_idx * num_heads * head_size + head_idx * head_size; - const scalar_t* prefix_head_ptr = prefix_output + head_offset; - const scalar_t* suffix_head_ptr = suffix_output + head_offset; - scalar_t* output_head_ptr = output + head_offset; + const uint src_head_offset = token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride; + const uint dst_head_offset = token_idx * num_heads * output_head_stride + + head_idx * output_head_stride; + const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; + const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; + scalar_t* output_head_ptr = output + dst_head_offset; float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; @@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel( reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size); \ + num_heads, head_size, prefix_head_stride, output_head_stride); \ } /*@brief Merges the attention states from prefix and suffix @@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint num_tokens = output.size(0); const uint num_heads = output.size(1); const uint head_size = output.size(2); + const uint prefix_head_stride = prefix_output.stride(1); + const uint output_head_stride = output.stride(1); const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); - TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1, - "output heads must be contiguous in memory"); - TORCH_CHECK( - prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1, - "prefix_output heads must be contiguous in memory"); - TORCH_CHECK( - suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1, - "suffix_output heads must be contiguous in memory"); float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); diff --git a/csrc/cache.h b/csrc/cache.h index f2a5ec0acf5c..cbe44c09eb62 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -58,6 +59,15 @@ void cp_gather_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); +// Gather and upconvert FP8 KV cache to BF16 workspace +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size); + // Indexer K quantization and cache function void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] - const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file + const torch::Tensor& cu_seq_lens); // [batch_size + 1] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a5457206c70..f11c5f24c12e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include 
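+  // The prefix/suffix inputs and the output may now use different head
+  // strides, so the source and destination offsets are computed separately.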
#include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel( const int quant_block_size, // quantization block size const int cache_block_size, // cache block size const int cache_stride, // stride for each token in kv_cache - const bool use_ue8m0 // use ue8m0 scale format + + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; @@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache( } namespace vllm { + +// Gather and upconvert FP8 KV cache tokens to BF16 workspace +// Similar to cp_gather_cache but specifically for FP8->BF16 conversion +__global__ void cp_gather_and_upconvert_fp8_kv_cache( + const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ seq_lens, // [BATCH] + const int32_t* __restrict__ workspace_starts, // [BATCH] + const int32_t block_size, const int32_t head_dim, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = workspace_starts[bid]; + const int32_t seq_len = seq_lens[bid]; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths + dst += seq_start * dst_entry_stride; + + const int tid = threadIdx.x; + + // Process each token in this split + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + const uint8_t* token_ptr = + src_cache + block_id * cache_block_stride + offset * cache_entry_stride; + __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; + + // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) + const uint8_t* no_pe_ptr = token_ptr; + const float* scales_ptr = reinterpret_cast(token_ptr + 512); + const __nv_bfloat16* rope_ptr = + reinterpret_cast(token_ptr + 512 + 16); + + // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) + if (tid < 512) { + // FP8 dequantization + const int tile = tid >> 7; // each tile is 128 elements + const float scale = scales_ptr[tile]; + const uint8_t val = no_pe_ptr[tid]; + dst_ptr[tid] = + fp8::scaled_convert<__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); + } else if (tid < 576) { + // Rope copy (64 bf16 elements) + const int rope_idx = tid - 512; + dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; + } + + // Move to next token + offset += 1; + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} + template // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // block_size. 
@@ -1202,6 +1280,57 @@ void cp_gather_cache( } } +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); + TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, + "workspace_starts must be int32"); + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == seq_lens.device(), + "src_cache and seq_lens must be on the same device"); + TORCH_CHECK(src_cache.device() == workspace_starts.device(), + "src_cache and workspace_starts must be on the same device"); + + TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); + TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); + TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(576); + + vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( + src_cache.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + block_table.data_ptr(), seq_lens.data_ptr(), + workspace_starts.data_ptr(), block_size, head_dim, + block_table_stride, cache_block_stride, cache_entry_stride, + dst_entry_stride); +} + // Macro to dispatch the kernel based on the data type. 
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::indexer_k_quant_and_cache_kernel \ diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 92f8bee5a47a..02c722ba031a 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata( input.casual = casual; input.isa = isa; input.enable_kv_split = enable_kv_split; - TORCH_CHECK(casual, "Only supports casual mask for now."); VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 98f55d7c014b..e3e077b845f4 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -186,7 +186,7 @@ struct AttentionMetadata { // - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2 // * q_tile_size * 4, partial output, max + sum (float) // Reduction scratchpad contains: -// - flags: bool array to indicate wether the split is finished +// - flags: bool array to indicate whether the split is finished // - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size // - max, sum: 2 * split_num * q_tile_size * 4 class AttentionScratchPad { @@ -1246,14 +1246,8 @@ class AttentionMainLoop { // rescale sum and partial outputs if (need_rescale) { // compute rescale factor -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif // rescale sum new_sum_val += rescale_factor * init_sum_val; @@ -1889,15 +1883,8 @@ class AttentionMainLoop { : curr_output_buffer; float rescale_factor = final_max > curr_max ? curr_max - final_max : final_max - curr_max; - -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif local_sum[head_idx] = final_max > curr_max ? 
final_sum + rescale_factor * curr_sum diff --git a/csrc/cpu/cpu_attn_macros.h b/csrc/cpu/cpu_attn_macros.h index 6458e4341937..35716a0790ab 100644 --- a/csrc/cpu/cpu_attn_macros.h +++ b/csrc/cpu/cpu_attn_macros.h @@ -60,4 +60,54 @@ #endif +#ifdef __aarch64__ + // Implementation copied from Arm Optimized Routines (expf AdvSIMD) + // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c + #include + #define DEFINE_FAST_EXP \ + const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \ + const float ln2_hi = 0x1.62e4p-1f; \ + const float ln2_lo = 0x1.7f7d1cp-20f; \ + const float c0 = 0x1.0e4020p-7f; \ + const float c2 = 0x1.555e66p-3f; \ + const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \ + const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \ + const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \ + const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \ + const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \ + const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \ + const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \ + const float32x4_t inf = \ + vdupq_n_f32(std::numeric_limits::infinity()); \ + const float32x4_t zero = vdupq_n_f32(0.0f); \ + auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \ + float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \ + float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \ + r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \ + uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \ + float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \ + float32x4_t r2 = vmulq_f32(r, r); \ + float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \ + float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \ + q = vfmaq_f32(q, p, r2); \ + p = vmulq_f32(c4, r); \ + float32x4_t poly = vfmaq_f32(p, q, r2); \ + poly = vfmaq_f32(scale, poly, scale); \ + const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \ + const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \ + poly = vbslq_f32(hi_mask, inf, poly); \ + return vbslq_f32(lo_mask, zero, poly); \ + }; \ + auto fast_exp = [&](vec_op::FP32Vec16& vec) \ + __attribute__((always_inline)) { \ + float32x4x4_t result; \ + result.val[0] = neon_expf(vec.reg.val[0]); \ + result.val[1] = neon_expf(vec.reg.val[1]); \ + result.val[2] = neon_expf(vec.reg.val[2]); \ + result.val[3] = neon_expf(vec.reg.val[3]); \ + return vec_op::FP32Vec16(result); \ + }; + +#endif // __aarch64__ + #endif \ No newline at end of file diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5199ba2af024..3dacfc7b2b7a 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -51,12 +51,13 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { if (node_id != -1) { node_ids.insert(node_id); } - TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ", - node_id, ", but CPU ", omp_cpu_ids.front(), - " is on NUMA node ", mem_node_id, - ". All CPUs should be on the same NUMA node for optimal " - "performance. Memory will be bound to NUMA node ", - mem_node_id, "."); + if (node_id != mem_node_id) { + TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ", + omp_cpu_ids.front(), " is on NUMA node ", mem_node_id, + ". All CPUs should be on the same NUMA node for optimal " + "performance. 
Memory will be bound to NUMA node ", + mem_node_id, "."); + } } // Concatenate all node_ids into a single comma-separated string if (!node_ids.empty()) { diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index e1d131e4a785..de0c505b7a62 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -118,6 +118,24 @@ } \ } +#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \ + if (expr) { \ + constexpr bool const_expr = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + __VA_ARGS__(); \ + } + +#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \ + if (group_size == 128) { \ + constexpr int const_group_size = 128; \ + __VA_ARGS__(); \ + } else if (group_size == 64) { \ + constexpr int const_group_size = 64; \ + __VA_ARGS__(); \ + } + #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \ switch (NUM_DIMS) { \ case 2: { \ diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index df47bb8dd1d7..58dc40201688 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu( } auto Y_all = at::empty({offsets[E], H}, x_c.options()); - at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) { c10::InferenceMode guard; - for (int64_t e = e_begin; e < e_end; ++e) { - const int64_t te = counts[e]; - if (te == 0) { + for (int64_t e = 0; e < E; ++e) { + int64_t start = std::max(offsets[e], idx_begin); + int64_t end = std::min(offsets[e + 1], idx_end); + int64_t te = end - start; + if (te <= 0) { continue; } - const int64_t start = offsets[e]; - auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); auto w13_e = w13_packed.select(/*dim=*/0, e); diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 69b4c1fb11d1..5fa367abd96f 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) { return cuda_cast(sigmoid_accurate(f)); } -template +template +__device__ inline T apply_scoring(T val) { + if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } else { + return val; + } +} + +template __device__ void topk_with_k2(T* output, T const* input, T const* bias, cg::thread_block_tile<32> const& tile, int32_t const lane_id, - int const num_experts_per_group, - int const scoring_func) { + int const num_experts_per_group) { // Get the top2 per thread T largest = neg_inf(); T second_largest = neg_inf(); if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; if (value > largest) { @@ -472,17 +476,11 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } else { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; largest = value; } } - - __syncwarp(); // Ensure all threads have valid data before reduction // Get the top2 warpwise T max1 = cg::reduce(tile, largest, cg::greater()); @@ -501,13 +499,12 @@ __device__ void topk_with_k2(T* 
output, T const* input, T const* bias, } } -template +template __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, int64_t const num_tokens, int64_t const num_cases, int64_t const n_group, - int64_t const num_experts_per_group, - int const scoring_func) { + int64_t const num_experts_per_group) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; @@ -525,21 +522,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - topk_with_k2(output, input, group_bias, tile, lane_id, - num_experts_per_group, scoring_func); + topk_with_k2(output, input, group_bias, tile, lane_id, + num_experts_per_group); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, bool renormalize, - double routed_scaling_factor, int scoring_func) { + double routed_scaling_factor) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = @@ -549,6 +546,11 @@ __global__ void group_idx_and_topk_idx_kernel( topk_values += case_id * topk; topk_indices += case_id * topk; + constexpr bool kUseStaticNGroup = (NGroup > 0); + // use int32 to avoid implicit conversion + int32_t const n_group_i32 = + kUseStaticNGroup ? NGroup : static_cast(n_group); + int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of(num_experts_per_group); @@ -574,17 +576,17 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx - int32_t target_num_min = WARP_SIZE - n_group + topk_group; + int32_t target_num_min = + WARP_SIZE - n_group_i32 + static_cast(topk_group); // The check is necessary to avoid abnormal input - if (lane_id < n_group && is_finite(group_scores[lane_id])) { + if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) { value = group_scores[lane_id]; } - int count_equal_to_top_value = WARP_SIZE - n_group; + int count_equal_to_top_value = WARP_SIZE - n_group_i32; int pre_count_equal_to_top_value = 0; // Use loop to find the largset top_group while (count_equal_to_top_value < target_num_min) { - __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { value = neg_inf(); @@ -604,7 +606,7 @@ __global__ void group_idx_and_topk_idx_kernel( int count_equalto_topkth_group = 0; bool if_proceed_next_topk = topk_group_value != neg_inf(); if (case_id < num_tokens && if_proceed_next_topk) { - for (int i_group = 0; i_group < n_group; i_group++) { + auto process_group = [&](int i_group) { if ((group_scores[i_group] > topk_group_value) || ((group_scores[i_group] == topk_group_value) && (count_equalto_topkth_group < num_equalto_topkth_group))) { @@ -613,11 +615,10 @@ __global__ void group_idx_and_topk_idx_kernel( i += WARP_SIZE) { T candidates = neg_inf(); if (i < num_experts_per_group) { - // Apply scoring function (if any) and add bias + // apply scoring function (if any) and add bias T input = scores[offset + i]; if (is_finite(input)) { - T score = (scoring_func == SCORING_SIGMOID) ? 
apply_sigmoid(input) - : input; + T score = apply_scoring(input); candidates = score + bias[offset + i]; } } @@ -627,12 +628,21 @@ __global__ void group_idx_and_topk_idx_kernel( count_equalto_topkth_group++; } } + }; + + if constexpr (kUseStaticNGroup) { +#pragma unroll + for (int i_group = 0; i_group < NGroup; ++i_group) { + process_group(i_group); + } + } else { + for (int i_group = 0; i_group < n_group_i32; ++i_group) { + process_group(i_group); + } } queue.done(); - __syncwarp(); // Get the topk_idx queue.dumpIdx(s_topk_idx); - __syncwarp(); } // Load the valid score value @@ -646,12 +656,13 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { // Load the score value (without bias) for normalization T input = scores[s_topk_idx[i]]; - value = - (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input; + value = apply_scoring(input); s_topk_value[i] = value; } - topk_sum += - cg::reduce(tile, cuda_cast(value), cg::plus()); + if (renormalize) { + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); + } } } @@ -660,13 +671,9 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { if (if_proceed_next_topk) { for (int i = lane_id; i < topk; i += WARP_SIZE) { - float value; - if (renormalize) { - value = cuda_cast(s_topk_value[i]) / topk_sum * - routed_scaling_factor; - } else { - value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; - } + float base = cuda_cast(s_topk_value[i]); + float value = renormalize ? (base / topk_sum * routed_scaling_factor) + : (base * routed_scaling_factor); topk_indices[i] = s_topk_idx[i]; topk_values[i] = value; } @@ -684,6 +691,45 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } +template +inline void launch_group_idx_and_topk_kernel( + cudaLaunchConfig_t const& config, T* scores, T* group_scores, + float* topk_values, IdxT* topk_indices, T const* bias, + int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, + int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool const renormalize, + double const routed_scaling_factor) { + auto launch = [&](auto* kernel_instance2) { + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + renormalize, routed_scaling_factor); + }; + + switch (n_group) { + case 4: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 8: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 16: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 32: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + default: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + } +} + template void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, @@ -694,7 +740,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, cudaStream_t const stream = 0) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; - auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchConfig_t config; config.gridDim = topk_with_k2_num_blocks; config.blockDim = BLOCK_SIZE; @@ -705,16 +750,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - 
num_tokens, num_cases, n_group, num_experts / n_group, - scoring_func); + auto const sf = static_cast(scoring_func); + int64_t const num_experts_per_group = num_experts / n_group; + auto launch_topk_with_k2 = [&](auto* kernel_instance1) { + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts_per_group); + }; + switch (sf) { + case SCORING_NONE: { + auto* kernel_instance1 = &topk_with_k2_kernel; + launch_topk_with_k2(kernel_instance1); + break; + } + case SCORING_SIGMOID: { + auto* kernel_instance1 = &topk_with_k2_kernel; + launch_topk_with_k2(kernel_instance1); + break; + } + default: + // should be guarded by higher level checks. + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; size_t dynamic_smem_in_bytes = warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, topk); - auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; config.gridDim = topk_with_k_group_num_blocks; config.blockDim = BLOCK_SIZE; config.dynamicSmemBytes = dynamic_smem_in_bytes; @@ -723,10 +785,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - renormalize, routed_scaling_factor, scoring_func); + switch (sf) { + case SCORING_NONE: { + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); + break; + } + case SCORING_SIGMOID: { + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); + break; + } + default: + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore index 77088552b85b..ba805f9250ec 100644 --- a/csrc/moe/marlin_moe_wna16/.gitignore +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index be5b68cc53e6..88f1055337fd 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -4,134 +4,282 @@ import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py +ARCHS = [] +SUPPORT_FP8 = False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. 
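+    # As a result, FP8-activation kernels are only generated when the build
+    # targets include SM 8.9 or SM 12.0.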
+ if arch in [89, 120]: + SUPPORT_FP8 = True + +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py // clang-format off +""".lstrip() +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. -SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # AWQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def 
generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - all_template_str_list = [] + result_dict = {} - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue - - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + if "16" in a_type and "16" in c_type and a_type != c_type: continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "false", + } - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + result_dict[(a_type, b_type, c_type)].append(config) - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" + kernel_selector_str = FILE_HEAD_COMMENT + for (a_type, b_type, c_type), config_list in result_dict.items(): + all_template_str_list = [] + for config in config_list: + s_type = config["s_type"] template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - 
thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=False, + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, ) - all_template_str_list.append(template_str) + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) + + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " + + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) + + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) + file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 6190f7ee21ec..57f5a17932d4 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -11,8 +11,9 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \ @@ -20,12 +21,13 @@ const float *__restrict__ topk_weights_ptr, int top_k, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem + bool use_fp32_reduce namespace MARLIN_NAMESPACE_NAME { -template shared // fetch pipeline - const int group_blocks, // 
number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -76,8 +77,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce, // whether to use fp32 global reduce - int max_shared_mem) {} + bool use_fp32_reduce // whether to use fp32 global reduce +) {} } // namespace MARLIN_NAMESPACE_NAME @@ -85,65 +86,148 @@ __global__ void Marlin( // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + 
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), 
"r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200 + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #endif + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -167,47 +251,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -221,12 +312,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); 
} // Wait until barrier reaches `count`, then lock for current threadblock. @@ -278,9 +370,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template ; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. + if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; extern __shared__ int4 sh[]; - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -355,34 +472,37 @@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); - scalar_t2 global_scale; + c_scalar_t2 global_scale; constexpr bool has_act_order = group_blocks == 0; - constexpr int pack_factor = 32 / w_type.size_bits(); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); - constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? 
prob_k : prob_k / num_groups; const int scales_expert_stride = - prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); + prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks - int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int parallel = num_tokens_past_padded / moe_block_size; int num_valid_blocks = parallel; if (is_ep) { @@ -395,7 +515,23 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + // we use DP + two-tile SK here + // part1: DP + // part2: two-tile SK + // see https://github.com/vllm-project/vllm/pull/24722 for more details + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -407,14 +543,15 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int block_id = -1; @@ -422,85 +559,89 @@ __global__ void Marlin( int old_expert_id = 0; int64_t B_expert_off = 0; - int4* sh_block_sorted_ids_int4 = sh; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_block_sorted_ids_int4 = sh + (is_a_8bit ? 
(4 * thread_m_blocks) : 0); int4* sh_rd_block_sorted_ids_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4; int4* sh_block_topk_weights_int4 = sh_rd_block_sorted_ids_int4 + moe_block_size / 4; // sh_block_topk_weights_int4 only need (moe_block_size / 4); // but we pad to align to 256 bytes - int4* sh_new = - sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2; int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); int32_t* sh_rd_block_sorted_ids = reinterpret_cast(sh_rd_block_sorted_ids_int4); - scalar_t2* sh_block_topk_weights = - reinterpret_cast(sh_block_topk_weights_int4); + c_scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); int32_t block_num_valid_tokens = 0; int32_t locks_off = 0; // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred(sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); - #pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; - } - __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); - #pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = - sh_block_sorted_ids[tid4 * 4 + i] / top_k; + if (lane_id == 0) + reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; if (mul_topk_weights) { - #pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - idx = idx < block_num_valid_tokens ? 
idx : 0; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = __hmul2( - global_scale, Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = Dtype::num2num2( - Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } + idx = idx < prob_m_top_k ? idx : 0; + c_scalar_t2 topk_weight_val = + Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx])); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + topk_weight_val = __hmul2(topk_weight_val, global_scale); } + sh_block_topk_weights[threadIdx.x] = topk_weight_val; } } + + __syncthreads(); + + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -511,9 +652,8 @@ __global__ void Marlin( old_expert_id = expert_id; if (num_invalid_blocks > 0) { - int skip_count = block_id == -1 ? par_id : 0; - block_id++; - for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + int skip_count = par_id; + for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) { expert_id = expert_ids_ptr[i]; if (expert_id != -1) { if (skip_count == 0) { @@ -528,9 +668,9 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - uint16_t val = scale2_ptr[expert_id]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[expert_id]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); @@ -550,10 +690,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. 
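(Reviewer aside: the slice initialization below partitions work in two phases, as noted in the "DP + two-tile SK" comment earlier. A hedged host-side model of the same arithmetic, with illustrative names and sample numbers:)

def split_tiles(global_mn_tiles, k_tiles, num_sms):
    part2_mn_tiles = global_mn_tiles
    part1_mn_iters = 0
    if global_mn_tiles > num_sms:
        # stream-K (part 2) takes the remainder tiles; if the remainder is
        # small (at most num_sms / 3), one extra full wave is moved from
        # part 1 into part 2 so part 2 stays large enough to split evenly
        part2_mn_tiles = global_mn_tiles % num_sms
        if part2_mn_tiles * 3 <= num_sms:
            part2_mn_tiles += num_sms
        part1_mn_iters = (global_mn_tiles - part2_mn_tiles) // num_sms
    # part 1 (data-parallel): part1_mn_iters whole output tiles per SM
    # part 2 (stream-K): k_tiles * part2_mn_tiles k-slices spread over SMs
    iters = -(-(k_tiles * part2_mn_tiles) // num_sms)  # ceil division
    return part1_mn_iters, part2_mn_tiles, iters

# split_tiles(global_mn_tiles=260, k_tiles=8, num_sms=128) -> (1, 132, 9)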
- auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -571,7 +712,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -605,25 +746,61 @@ __global__ void Marlin( par_id++; update_next_moe_block_data(); } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } }; - update_next_moe_block_data(); - init_slice(true); + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + update_next_moe_block_data(); + if (is_a_8bit) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + update_next_moe_block_data(); + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; + int a_gl_stride = prob_k / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -632,24 +809,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 
1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -662,7 +840,8 @@ __global__ void Marlin( constexpr int act_s_max_num_groups = 32; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -677,7 +856,6 @@ __global__ void Marlin( // Global A read index of current thread. int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; - // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -685,17 +863,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; + + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += B_expert_off + b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -706,58 +889,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -773,12 +952,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -805,18 +988,13 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; // Shared memory storage for global fetch pipelines. constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; @@ -845,19 +1023,12 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = moe_block_size + - stages * (g_idx_stage + zp_sh_stage) + - sh_s_size + sh_b_red_bias_size; - - // all remaining shared memory is used to cache A (input) - // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` - int sh_a_max_row = - ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -865,6 +1036,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit && group_blocks != -1) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. auto zero_accums = [&]() { #pragma unroll @@ -908,43 +1097,36 @@ __global__ void Marlin( } } }; - // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. 
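(Reviewer aside: fetch_to_shared / fetch_to_registers below implement a `stages`-deep cp.async prefetch pipeline. The sketch is a simplified, purely illustrative host-side model of that scheduling; the real kernel additionally double-buffers registers and synchronizes with cp.async fences:)

def pipelined_tiles(num_tiles, stages, fetch, compute):
    # ring buffer of `stages` shared-memory slots
    for p in range(min(stages - 1, num_tiles)):
        fetch(tile=p, slot=p % stages)          # fill the pipeline
    for p in range(num_tiles):
        nxt = p + stages - 1
        if nxt < num_tiles:
            fetch(tile=nxt, slot=nxt % stages)  # prefetch ahead
        compute(tile=p, slot=p % stages)        # consume the oldest slot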
- bool should_load_a = true; - int max_num_stage_groups = - ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; - max_num_stage_groups = max(max_num_stage_groups, 1); - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, - int pipe_a = 0) { + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - if (should_load_a) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; - int64_t sorted_row = 0; - if (!m_block_size_8 || row < 8) - sorted_row = sh_rd_block_sorted_ids[row]; - int64_t true_idx = - sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], - row < block_num_valid_tokens); - } + for (int i = 0; i < a_sh_wr_iters; i++) { + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j + B_expert_off); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; @@ -964,44 +1146,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if 
(zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1035,18 +1197,18 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -1070,53 +1232,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - } else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { + } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } else { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -1137,18 +1300,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1203,18 +1363,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1223,21 +1381,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int warp_row = warp_id / tb_n_warps; - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 
2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1256,29 +1404,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1289,33 +1426,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1325,14 +1475,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1343,61 +1493,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr (b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 
4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? 
frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1411,7 +1668,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1426,7 +1684,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1435,24 +1694,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1468,13 +1729,13 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; if (!is_th_active) { return; } - int c_gl_stride = prob_n / 8; + int c_gl_stride = prob_n / 8 * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; @@ -1485,7 +1746,7 @@ __global__ void Marlin( } else { c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); } constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; @@ -1504,7 +1765,13 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + sh_red_int2[c_sh_wr + c_sh_wr_delta * i] = c_int2[true_idx]; + } else { + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } } } } @@ -1512,29 +1779,37 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta]); } int c_idx; @@ -1547,7 +1822,12 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - C[true_idx] = c; + if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[true_idx] = *reinterpret_cast(c_f16); + } else { + C[true_idx] = *reinterpret_cast(c_f16); + } } } } @@ -1561,10 +1841,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1632,7 +1912,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1641,49 +1921,49 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) { + for (int j = 0; j < (is_a_8bit ? 
2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1721,24 +2001,26 @@ __global__ void Marlin( if (row < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[row]; int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; - scalar_t2 topk_weight_score; + c_scalar_t2 topk_weight_score; if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; if (use_atomic_add && slice_count > 1 || mul_topk_weights) { - scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + if (mul_topk_weights) { #pragma unroll - for (int a = 0; a < 4; a++) { - scalar_t2 res = sh_red_half2[a]; - if (mul_topk_weights) { - res = __hmul2(res, topk_weight_score); + for (int a = 0; a < 4; a++) { + sh_red_half2[a] = __hmul2(sh_red_half2[a], topk_weight_score); } + } - if (use_atomic_add && slice_count > 1) { - atomicAdd(&C_half2[a], res); - } else { - C_half2[a] = res; - }; + if (use_atomic_add && slice_count > 1) { + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[true_idx] = *reinterpret_cast(sh_red_half2); } } else { C[true_idx] = sh_red[c_sh_rd]; @@ -1772,7 +2054,7 @@ __global__ void Marlin( } } } - fetch_to_shared(i, i, i < slice_iters, i); + fetch_to_shared(i, i, i < slice_iters); } zero_accums(); @@ -1797,73 +2079,100 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. - for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; - stage_group_id++) { #pragma unroll - for (int pipe = 0; pipe < stages;) { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - int idx = - (pipe >= stages && stage_group_id == max_num_stage_groups - 1) - ? (pipe - stages) - : (pipe + stage_group_id * stages); - fetch_to_registers(k + 1, pipe % stages, idx); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) - ? 
(pipe - 1) - : (pipe + (stage_group_id + 1) * stages - 1); - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages, idx); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd_col += a_gl_rd_delta_o * stages; - if constexpr (has_act_order) { - slice_k_start += tb_k * stages; - - if (slice_k_start < prob_k) { - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, - last_group_id); - __syncthreads(); - } + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); } } + slice_iters--; if (slice_iters == 0) { break; } } + a_gl_rd_col += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } + } + // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. 
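// For int8 activations, the is_a_8bit branch just below reinterprets each raw
// MMA accumulator as an int32 and converts it to float; for both 8-bit
// activation types it then rescales by the per-token activation scale
// (frag_a_s) before write-out. A minimal host-side sketch of that
// dequantization arithmetic, using hypothetical names (`acc`, `a_scale`,
// `b_scale`); the kernel applies the weight/group scale separately (in
// matmul_a8 or in the channelwise epilogue):
inline float dequant_w4a8_accum_sketch(int acc, float a_scale, float b_scale) {
  // An integer MMA produces an exact int32 dot product; multiplying by the
  // activation scale and the weight scale recovers the floating-point result.
  return static_cast<float>(acc) * a_scale * b_scale;
}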
if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1881,20 +2190,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1904,26 +2220,48 @@ __global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 
< tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1947,7 +2285,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1956,37 +2295,22 @@ __global__ void Marlin( if (last || use_atomic_add) // only the last block in a slice actually writes the result write_result(last); - int old_slice_row = slice_row; slice_row = 0; - slice_col_par++; - slice_col++; - is_first_matmul_in_slice = true; - init_slice(); - - // Should we load A matrix in next slice? - // `slice_col == 0`: when move to a new moe block - // `old_slice_row > 0`: - // when the last slice is not starting from k_index == 0 - // (only happen when it is the first slice of a threadblock) - // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: - // when the required shared memory size is larger than - // the remaining shared memory - if (slice_col == 0 || old_slice_row || - prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { - should_load_a = true; + if (!in_part2) { + slice_col_par += gridDim.x; } else { - should_load_a = false; + slice_col_par++; + slice_col++; } + is_first_matmul_in_slice = true; + init_slice(); if (slice_iters) { - a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd_col = + a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + b_gl_rd = B_expert_off + b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1996,8 +2320,26 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + 
threadIdx.x % zp_sh_stride; + } } start_pipes(); } diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 601e2aa6f991..4fd8fc5c5420 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template -__global__ void permute_cols_kernel( - int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, - const int32_t* __restrict__ sorted_token_ids_ptr, - const int32_t* __restrict__ expert_ids_ptr, - const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, - int size_k, int top_k) {}; - -} // namespace marlin - -torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, - std::optional const& b_bias_or_none, torch::Tensor& b_scales, - std::optional const& b_zeros_or_none, - std::optional const& g_idx_or_none, - std::optional const& perm_or_none, torch::Tensor& workspace, - torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, - torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, - int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. template @@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, - int is_zp_float) { + int is_zp_float, bool is_a_8bit) { int pack_factor = 32 / num_bits; // Get B size @@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) - int sh_block_meta_size = tb_m * 4; - int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_block_meta_size = tb_m * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 
1 : 2); int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; @@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, int is_zp_float, - int max_shared_mem) { + int max_shared_mem, bool is_a_8bit) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, } // Check that pipeline fits into cache - int cache_size = get_kernel_cache_size( - th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + int cache_size = + get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - 
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 
4, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } +#include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, int group_size, bool has_act_order, + bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms, + bool is_a_8bit) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem)) { + is_k_full, has_zp, is_zp_float, max_shared_mem - 512, + is_a_8bit)) { continue; } int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + is_a_8bit); int group_blocks = 0; if (!has_act_order) { group_blocks = group_size == -1 ? 
-1 : (group_size / 16); } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; - if (thread_m_blocks > 1) { - exec_cfg = {1, th_config}; - break; - } else { - cudaFuncAttributes attr; - cudaFuncGetAttributes(&attr, kernel); - int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; - int allow_count = min(device_max_reg_size / reg_size, - max_shared_mem / (cache_size + 1024)); + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1536)); + if (thread_m_blocks == 1) allow_count = max(min(allow_count, 4), 1); - if (allow_count > count) { - count = allow_count; - exec_cfg = {count, th_config}; - }; + else + allow_count = max(min(allow_count, 2), 1); + + if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) { + allow_count = + max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1); } + + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; } return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, void* sorted_token_ids, void* expert_ids, - void* num_tokens_past_padded, void* topk_weights, - int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_bias, - bool has_act_order, bool is_k_full, bool has_zp, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, void* sorted_token_ids, + void* expert_ids, void* num_tokens_past_padded, + void* topk_weights, int moe_block_size, int num_experts, + int top_k, bool mul_topk_weights, bool is_ep, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& a_type, vllm::ScalarType const& b_type, + vllm::ScalarType const& c_type, vllm::ScalarType const& s_type, + bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int blocks_per_sm, + bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; - - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); - } + bool is_a_8bit = a_type.size_bits() == 8; TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK(major_capability * 10 + minor_capability >= 89, + "FP8 only support Ada Lovelace or newer GPUs."); + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + // Set thread config exec_config_t exec_cfg; thread_config_t thread_tfg; if (thread_k != -1 && thread_n != -1) { - thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; - exec_cfg = exec_config_t{1, thread_tfg}; + thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64}; + if (blocks_per_sm == -1) blocks_per_sm = 1; + exec_cfg = exec_config_t{blocks_per_sm, thread_tfg}; TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts, + top_k, thread_m_blocks, m_block_size_8, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms, + is_a_8bit); thread_tfg = exec_cfg.tb_cfg; } @@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; - TORCH_CHECK( - is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem), - "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - ", thread_k = ", thread_tfg.thread_k, - ", thread_n = ", thread_tfg.thread_n, - ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", - prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", 
group_size = ", group_size, ", has_act_order = ", has_act_order, - ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, - ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); - - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, - has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem, is_a_8bit), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); + + int sh_cache_size = + get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce); // clang-format on } } // namespace MARLIN_NAMESPACE_NAME torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + bool is_zp_float, int64_t thread_k, int64_t thread_n, + int64_t blocks_per_sm) { + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } 
else if (a.scalar_type() == at::ScalarType::BFloat16) { + a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); + } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); + int num_experts = b_q_weight.size(0); if (moe_block_size != 8) { TORCH_CHECK(moe_block_size % 16 == 0, @@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // sms: number of SMs to use for the kernel int sms = -1; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce && !use_atomic_add) { // max num of threadblocks is sms * 4 long max_c_tmp_size = min( @@ 
-846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, 
top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, - "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, num_experts, top_k, + mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(), + a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce, + is_zp_float); + return c; } -#endif - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); -} +} \ No newline at end of file diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index b3d0c0aa58e9..5c9e47402408 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -14,7 +14,6 @@ namespace vllm { namespace moe { - namespace batched_moe_align_block_size { // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. @@ -80,17 +79,32 @@ __global__ void batched_moe_align_block_size_kernel( } // namespace batched_moe_align_block_size template -__global__ void moe_align_block_size_kernel( +__device__ void _moe_align_block_size( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) { + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id, + int32_t topk_num, int32_t* token_mask, bool has_expert_map) { extern __shared__ int32_t shared_counts[]; - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[it] = numel; + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + int cumsum_offset = (num_experts + 1) * model_offset; + + // Use separate threadblocks to fill sorted_token_ids. + // This is safe since the current kernel does not use sorted_token_ids. 
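// The counting path of this kernel tallies tokens per expert and then builds
// a cumulative sum in which each expert's count is rounded up to a multiple
// of block_size, so every expert's tokens start on a block boundary (the
// padded total is what gets written to total_tokens_post_pad). A host-side
// sketch of that padding step, with hypothetical names (`counts`, `cumsum`):
inline void padded_cumsum_sketch(const int* counts, int num_experts,
                                 int block_size, int* cumsum) {
  cumsum[0] = 0;
  for (int e = 1; e <= num_experts; ++e) {
    // round the per-expert count up to the next multiple of block_size
    int padded = (counts[e - 1] + block_size - 1) / block_size * block_size;
    cumsum[e] = cumsum[e - 1] + padded;
  }
  // cumsum[num_experts] is the padded total number of token slots
}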
+ if (blockIdx.x % 2) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += blockDim.x) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + return; } const int warp_id = threadIdx.x / WARP_SIZE; @@ -112,9 +126,16 @@ __global__ void moe_align_block_size_kernel( if (expert_id >= num_experts) { continue; } + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], + mask); } __syncthreads(); @@ -135,48 +156,196 @@ __global__ void moe_align_block_size_kernel( int cumsum_val; BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); if (expert_id <= num_experts) { - cumsum[expert_id] = cumsum_val; + cumsum[cumsum_offset + expert_id] = cumsum_val; } if (expert_id == num_experts) { - *total_tokens_post_pad = cumsum_val; + total_tokens_post_pad[model_offset] = cumsum_val; } __syncthreads(); if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + for (int i = cumsum[cumsum_offset + threadIdx.x]; + i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; } } // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; - const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); - for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { - expert_ids[i] = 0; + const size_t fill_start_idx = + cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } +} + +template +__device__ void _moe_align_block_size_small_batch_expert( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, + int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, + int32_t* token_mask, bool has_expert_map) { + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + + // Use an additional group of threads to fill sorted_token_ids. + // Since the current kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. 
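+  // The first fill_threads threads only write the sentinel value (numel) and
+  // then execute the same number of __syncthreads() calls as the counting
+  // threads, so every barrier in the block stays matched before they exit.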
+ if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + tokens_cnts[(tid + 1) * num_experts + expert_id] += mask; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + total_tokens_post_pad[model_offset] = + static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = tid; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + + if (token_mask == nullptr || token_mask[i / topk_num]) { + sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } } } template -__global__ void count_and_sort_expert_tokens_kernel( +__device__ void _count_and_sort_expert_tokens( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel, int32_t num_experts) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, + int32_t model_offset, int32_t topk_num, bool has_expert_map) { + const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.y; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; if (expert_id >= num_experts) { continue; } - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; + + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + + if (token_mask == nullptr || token_mask[i / 
topk_num]) { + int32_t rank_post_pad = atomicAdd( + &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = + i; + } } } +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), + 0, 0, topk_num, nullptr, has_expert_map); +} + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); +} + template __global__ void moe_sum_kernel( scalar_t* __restrict__ out, // [..., d] @@ -193,78 +362,111 @@ __global__ void moe_sum_kernel( } } -template +template __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t max_num_tokens_padded) { - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[it] = numel; - } - - const size_t tid = threadIdx.x; - const size_t stride = blockDim.x; - - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; - int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, + bool has_expert_map) { + _moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, + has_expert_map); +} - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; +template +__global__ void moe_lora_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, int32_t experts_per_warp, + int32_t padded_num_experts, int32_t* lora_ids, + int32_t* __restrict__ token_mask, bool has_expert_map) { + int lora_idx = 
blockIdx.x / 2; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; } - for (size_t i = tid; i < numel; i += stride) { - ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; + // Populate the token_mask based on the token-LoRA mapping + int num_tokens = numel / topk_num; + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; + } } __syncthreads(); - if (threadIdx.x < num_experts) { - tokens_cnts[threadIdx.x] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[i * num_experts + threadIdx.x] += - tokens_cnts[(i - 1) * num_experts + threadIdx.x]; - } + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, + &token_mask[(lora_id * num_tokens)], has_expert_map); +} + +template +__global__ void lora_count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, + int32_t* lora_ids, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1) { + return; } - __syncthreads(); + int num_tokens = numel / topk_num; - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = - cumsum[i - 1] + - CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * - block_size; - } - *total_tokens_post_pad = static_cast(cumsum[num_experts]); + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, + topk_num, has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_small_batch_expert_kernel( + scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, + int32_t* token_mask, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; } - __syncthreads(); + int num_tokens = numel / topk_num; + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; } } - // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; - const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); - for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { - expert_ids[i] = 0; - } + __syncthreads(); - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t 
rank_post_pad = - tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[threadIdx.x * num_experts + expert_id]; - } + _moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, + -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], + has_expert_map); } } // namespace moe @@ -275,7 +477,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t padded_num_experts = @@ -287,14 +490,19 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, // BlockScan uses 1024 threads and assigns one thread per expert. TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -304,43 +512,58 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; auto small_batch_expert_kernel = vllm::moe::moe_align_block_size_small_batch_expert_kernel< - scalar_t>; - small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( + scalar_t, fill_threads>; + small_batch_expert_kernel<<<1, fill_threads + threads, + shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), sorted_token_ids.size(0)); + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, block_size, + topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), + has_expert_map); } else { + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); auto align_kernel = vllm::moe::moe_align_block_size_kernel; size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); - align_kernel<<<1, threads, shared_mem_size, stream>>>( + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, - padded_num_experts, experts_per_warp, block_size, - 
topk_ids.numel(), cumsum_buffer.data_ptr(), - sorted_token_ids.size(0)); + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, padded_num_experts, + experts_per_warp, block_size, topk_ids.numel(), + cumsum_buffer.data_ptr(), sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; const int max_blocks = 65535; const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); auto sort_kernel = vllm::moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( + sort_kernel<<>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel(), num_experts); + cumsum_buffer.data_ptr(), expert_map.data_ptr(), + topk_ids.numel(), num_experts, sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); } }); } @@ -414,3 +637,123 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] break; } } + +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids, std::optional maybe_expert_map) { + const int topk_num = topk_ids.size(1); + + TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. 
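+  // (Same limit as the non-LoRA moe_align_block_size path above; the LoRA
+  // kernels reuse the shared _moe_align_block_size and
+  // _moe_align_block_size_small_batch_expert device helpers.)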
+ TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = + torch::empty({max_loras * topk_ids.size(0)}, options_int); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } + + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = + (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = + vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = + torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = + vllm::moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), cumsum.data_ptr(), + WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = + vllm::moe::lora_count_and_sort_expert_tokens_kernel; + + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), cumsum.data_ptr(), + expert_map.data_ptr(), topk_ids.numel(), num_experts, + max_num_tokens_padded, topk_num, token_mask.data_ptr(), + lora_ids.data_ptr(), has_expert_map); + } + }); +} \ No newline at end of file diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu deleted file mode 100644 index 
360f1312cf57..000000000000 --- a/csrc/moe/moe_lora_align_sum_kernels.cu +++ /dev/null @@ -1,174 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "../cuda_compat.h" -#include "../dispatch_utils.h" -#include "core/math.hpp" - -namespace { - -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) { - return row * total_col + col; -} - -} // namespace - -// TODO: Refactor common parts with moe_align_sum_kernels -template -__global__ void moe_lora_align_sum_kernel( - scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, - int64_t block_size, int num_experts, int max_loras, size_t numel, - int max_num_tokens_padded, int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled, - int32_t* lora_ids) { - const size_t tokens_per_thread = div_ceil(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - int lora_idx = blockIdx.x; - int lora_id = lora_ids[lora_idx]; - if (lora_id == -1 || adapter_enabled[lora_id] == 0) { - return; - } - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; - token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); - - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel; - } - - // Initialize expert_ids with -1 - for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) { - expert_ids[lora_id * max_num_m_blocks + it] = -1; - } - - // Initialize total_tokens_post_pad with 0 - if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; - } - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int mask = token_lora_mapping[i / topk_num] == lora_id; - int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]); - tokens_cnts[idx] += mask; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - total_tokens_post_pad[lora_id] = static_cast(cumsum[num_experts]); - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. 
- */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] = - threadIdx.x; - } - } - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - - int mask = (int)token_lora_mapping[i / topk_num] == lora_id; - atomicAdd( - &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)], - (i - numel) * mask); - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask; - } -} - -void moe_lora_align_block_size( - torch::Tensor topk_ids, torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, int64_t max_loras, - int64_t max_num_tokens_padded, int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids) { - const int topk_num = topk_ids.size(1); - - TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); - - int device_max_shared_mem; - auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE, - TORCH_CHECK(num_thread <= 1024, - "num_thread must be less than 1024, " - "and fallback is not implemented yet."); - const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + - (num_experts + 1) * sizeof(int32_t); - - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, - "Shared memory usage exceeds device limit, and global memory " - "fallback is not implemented yet."); - } - - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - dim3 blockDim(num_thread); - auto kernel = moe_lora_align_sum_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), block_size, num_experts, - max_loras, topk_ids.numel(), max_num_tokens_padded, - max_num_m_blocks, sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), lora_ids.data_ptr()); - }); -} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 11c6875f7f1d..337dcc50b079 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map); void batched_moe_align_block_size(int64_t max_tokens_per_batch, int64_t block_size, @@ -26,7 +27,7 @@ void moe_lora_align_block_size( int64_t max_num_tokens_padded, int64_t max_num_m_blocks, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor 
num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids); + torch::Tensor lora_ids, std::optional maybe_expert_map); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index bd95ade40a08..779ad70ad1e0 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "moe_align_block_size(Tensor topk_ids, int num_experts," " int block_size, Tensor! sorted_token_ids," " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); + " Tensor! num_tokens_post_pad," + " Tensor? maybe_expert_map) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); // Aligning the number of tokens to be processed by each expert such @@ -46,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor !experts_ids," " Tensor !num_tokens_post_pad," " Tensor !adapter_enabled," - " Tensor !lora_ids) -> () "); + " Tensor !lora_ids," + " Tensor? maybe_expert_map) -> () "); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); #ifndef USE_ROCM @@ -63,16 +65,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none," - "Tensor! b_scales, Tensor? global_scale, Tensor? " + "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? " "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! topk_weights, int moe_block_size, int top_k, " - "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "bool mul_topk_weights, bool is_ep, int b_type_id," "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," - "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + "bool use_fp32_reduce, bool is_zp_float," + "int thread_k, int thread_n, int blocks_per_sm) -> Tensor"); + m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" diff --git a/csrc/ops.h b/csrc/ops.h index f8bdc61aaa8e..37e3aaf7499d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -52,14 +52,13 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); -#ifndef USE_ROCM void merge_attn_states(torch::Tensor& output, std::optional output_lse, const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); - +#ifndef USE_ROCM void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] @@ -103,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& output_mask, const torch::Tensor& repetition_penalties); -void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1); +void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK); void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, - const torch::Tensor& seq_lens, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1); + const torch::Tensor& seqLens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, @@ -129,6 +131,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, std::optional scale_ub, std::optional residual); +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -253,7 +262,8 @@ void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -300,6 +310,14 @@ void per_token_group_quant_int8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double int8_min, double int8_max); + +// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales. 
+void per_token_group_quant_8bit_packed(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s_packed, + int64_t group_size, double eps, + double min_8bit, double max_8bit); + #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh new file mode 100644 index 000000000000..fec142d0d87a --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh @@ -0,0 +1,104 @@ +// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +// ElementB is int32 (packed int4) +// ElementGroupScale is cutlass::Array (packed fp8) +template +__global__ void get_group_gemm_starts( + int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, + ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int, + ElementB* b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, + ElementAccumulator* b_scales_base_as_int, + ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k, + int64_t scale_k) { + int expert_id = threadIdx.x; + + int64_t expert_offset = expert_offsets[expert_id]; + + // same as w8a8 + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset; + b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id); + + // w4a8 specific + constexpr int pack_factor = 8; // pack 8 int4 into int32 + b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor); + b_group_scales_offsets[expert_id] = + b_group_scales_base_as_int + (expert_id * scale_k * n); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts> \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast**>( \ + b_group_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast*>( \ + b_group_scales.data_ptr()), \ + n, k, scale_k); \ + } + +namespace { + +void run_get_group_gemm_starts( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor& out_tensors, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& b_group_scales, const int64_t b_group_size) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32 + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_group_scales.dtype() == + torch::kFloat8_e4m3fn); // the 
underlying torch type is e4m3 + TORCH_CHECK(out_tensors.dtype() == + torch::kBFloat16); // only support bf16 for now + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); + + int num_experts = static_cast(expert_offsets.size(0)); + // logical k, n + int64_t n = out_tensors.size(1); + int64_t k = a_tensors.size(1); + int64_t scale_k = cutlass::ceil_div(k, b_group_size); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu new file mode 100644 index 000000000000..4b425790dbac --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -0,0 +1,483 @@ +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +// vllm includes +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" +#include "cutlass_extensions/common.hpp" + +#include "core/registration.h" +#include "get_group_starts.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "w4a8_utils.cuh" + +namespace vllm::cutlass_w4a8_moe { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per + // group +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; + +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr PackFactor = 8; // 8 int4 packed into int32 + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = + 128 / + cutlass::sizeof_bits::value; // Alignment of A matrix in units of + // elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input +// layouts +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; + +// Need to pass a pointer type to make the 3rd dimension of Stride be _0 +using StrideA = + cute::remove_pointer_t>; +using StrideB = + cute::remove_pointer_t>; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by 
the same thread in +// contiguous locations in global memory. It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout>, StrideB>{})); + +using ElementScale = cutlass::float_e4m3_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + +// per-channel and per-token scales for epilogue +using ElementSChannel = float; + +template +struct W4A8GroupedGemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // per-channel, per-token scales epilogue + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogueArray; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementSChannel, ElementC, + typename cutlass::layout::LayoutTranspose::type*, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type*, + AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + // =========================================================== MIXED INPUT + // WITH SCALES + // =========================================================================== + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. 
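+  // With the SwapAB convention used here, the quantized B operand (packed
+  // int4 data, its group scales and the reordered layout) is passed as the
+  // first mainloop operand and the fp8 A operand second.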
+ using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>; + + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::InternalStrideC; + using StrideD = typename GemmKernelShuffled::InternalStrideD; + + using StrideC_ref = cutlass::detail::TagToStrideC_t; + using StrideD_ref = cutlass::detail::TagToStrideC_t; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + using StrideS_ref = cutlass::detail::TagToStrideB_t; + + // static asserts for passing in strides/layouts + // pack to 2x int64 + static_assert(sizeof(StrideS) == 2 * sizeof(int64_t)); + // pack to 3xint32, + static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, + "LayoutB_Reordered size must be divisible by 4 bytes"); + + static void grouped_mm( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides) { + auto device = a_tensors.device(); + auto device_id = device.index(); + const at::cuda::OptionalCUDAGuard device_guard(device); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + int num_experts = static_cast(expert_offsets.size(0)); + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(device); + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); + + // get the correct offsets to pass to gemm + run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs, + a_tensors, b_tensors, out_tensors, a_scales, + b_scales, b_group_scales, b_group_size); + + // construct args + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + Args arguments; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes_torch.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + // SwapAB so B operands come first + MainloopArguments mainloop_arguments{ + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast**>( + b_group_scales_ptrs.data_ptr()), + 
static_cast(group_scale_strides.data_ptr()), + static_cast(b_group_size)}; + + EpilogueArguments epilogue_arguments{ + // since we are doing SwapAB the channel scales comes first, then token + // scales + ChTokScalesEpilogue::prepare_args( // see ScaledEpilogueArray + static_cast( + b_scales_ptrs.data_ptr()), // per-channel + static_cast( + a_scales_ptrs.data_ptr()), // per-token + true, true), + nullptr, // C + static_cast(c_strides.data_ptr()), // C + static_cast(out_ptrs.data_ptr()), // D + static_cast(c_strides.data_ptr()) // D + }; + + static const cutlass::KernelHardwareInfo hw_info{ + device_id, + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; + + arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, + mainloop_arguments, epilogue_arguments, hw_info}; + + // Allocate workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + +// Kernel_TileShape_ClusterShape_Schedule +using Kernel_128x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_128x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x32_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x32_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x64_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x64_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x128_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x128_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_128x256_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +void mm_dispatch( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, const std::string& schedule) { + if (schedule == "Kernel_128x16_1x1x1_Coop") { + Kernel_128x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_128x16_2x1x1_Coop") { + Kernel_128x16_2x1x1_Coop::grouped_mm( + 
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_1x1x1_Coop") { + Kernel_256x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_2x1x1_Coop") { + Kernel_256x16_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_1x1x1_Coop") { + Kernel_256x32_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_2x1x1_Coop") { + Kernel_256x32_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_1x1x1_Coop") { + Kernel_256x64_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_2x1x1_Coop") { + Kernel_256x64_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_1x1x1_Coop") { + Kernel_256x128_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_2x1x1_Coop") { + Kernel_256x128_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_128x256_2x1x1_Coop") { + Kernel_128x256_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else { + TORCH_CHECK(false, + "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); + } +} + +void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, + std::optional maybe_schedule) { + // user has specified a schedule + if (maybe_schedule) { + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, + *maybe_schedule); + return; + } + + // use heuristic + int m_full = a_tensors.size(0); 
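+  // m_full counts all routed (token, expert) rows of the fp8 activation
+  // tensor; the schedule heuristic below works from the average rows per
+  // expert.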
+ int n = b_tensors.size(1); + int k = b_tensors.size(2) * PackFactor; // logical k + int num_experts = b_tensors.size(0); + // per-expert batch size assuming uniform distribution + int m_expert = m_full / num_experts; + + std::string schedule; + if (m_expert <= 16) { + schedule = "Kernel_128x16_2x1x1_Coop"; + } else if (m_expert <= 32) { + schedule = "Kernel_256x32_1x1x1_Coop"; + } else if (m_expert <= 64) { + schedule = "Kernel_256x64_1x1x1_Coop"; + } else if (m_expert <= 128) { + schedule = "Kernel_256x128_2x1x1_Coop"; + } else { // m_expert > 128 + schedule = "Kernel_128x256_2x1x1_Coop"; + } + + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, schedule); +} + +std::tuple encode_and_reorder_int4b( + torch::Tensor const& b_tensors) { + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); + TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) + TORCH_CHECK(b_tensors.is_contiguous()); + TORCH_CHECK(b_tensors.is_cuda()); + + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; // logical k + + // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. + // These misalignments cause silent OOB unless run under Compute Sanitizer. + TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); + TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); + + // we will store the layout to an int32 tensor; + // this is the number of elements we need per layout + constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); + + torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); + int num_experts = static_cast(b_tensors.size(0)); + + auto b_ptr = static_cast(b_tensors.const_data_ptr()); + auto b_packed_ptr = static_cast(b_tensors_packed.data_ptr()); + + // multiply by ull so result does not overflow int32 + size_t num_int4_elems = 1ull * num_experts * n * k; + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, + num_int4_elems); + TORCH_CHECK(ok, "unified_encode_int4b failed"); + + // construct the layout once; assumes each expert has the same layout + using LayoutType = LayoutB_Reordered; + std::vector layout_B_reordered_host(num_experts); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}}); + auto shape_B = cute::make_shape(n, k, Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B); + LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + + // reorder weights for each expert + for (int i = 0; i < num_experts; i++) { + // since the storage type of int4b is 1 byte but one element is 4 bits + // we need to adjust the offset + int64_t offset = + 1ull * i * n * k * cutlass::sizeof_bits::value / 8; + cutlass::reorder_tensor(b_packed_ptr + offset, layout_B, + layout_B_reordered); + } + + // save the packed layout to torch tensor so we can re-use it + auto cpu_opts = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + torch::Tensor layout_cpu = + torch::empty({num_experts, layout_width}, cpu_opts); + + int32_t* layout_data = layout_cpu.data_ptr(); + for (int i = 0; i < num_experts; ++i) { + std::memcpy(layout_data + i * layout_width, // dst (int32*) + &layout_B_reordered, // src (LayoutType*) + sizeof(LayoutType)); // number of bytes + } + + torch::Tensor packed_layout = + layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); + + return {b_tensors_packed, packed_layout}; +} + 
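The encode-and-reorder entry point above relies on vllm::cutlass_w4a8_utils::unified_encode_int4b (defined in w4a8_utils.cu further below), which remaps int4 nibbles through a constant-memory LUT. The following is a minimal standalone host-side sketch of that remapping, for reference only: map_nib mirrors the lambda used to build the LUT, while map_byte and the byte value in main are illustrative additions.

// Standalone illustration of the int4 nibble remapping used by
// unified_encode_int4b: values 1..7 map to (8 - v); 0 and 8..15 are kept.
// Two nibbles are remapped per byte, matching the 256-entry LUT.
#include <cstdint>
#include <cstdio>

static uint8_t map_nib(uint8_t v) {
  return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
}

static uint8_t map_byte(uint8_t b) {
  uint8_t lo = b & 0xF;
  uint8_t hi = (b >> 4) & 0xF;
  return uint8_t((map_nib(hi) << 4) | map_nib(lo));
}

int main() {
  // Byte 0x31 holds the nibbles 3 and 1 -> remapped to 5 and 7 -> 0x57.
  printf("0x31 -> 0x%02X\n", (unsigned)map_byte(0x31));
  return 0;
}

The device kernel applies the same byte-wise mapping 16 bytes at a time via the LUT in constant memory, so the host sketch and the GPU path agree on every packed byte.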
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_moe_mm", &mm); + m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8_moe +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu index 2d1568b08651..f77af06cd6c0 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -7,6 +7,7 @@ #include #include #include "cutlass_extensions/torch_utils.hpp" +#include "w4a8_utils.cuh" #include "core/registration.h" @@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { return packed_scales; } -/* - GPU-accelerated implementation of cutlass::unified_encode_int4b. - Constructs a lookup table in constant memory to map 8 bits - (two 4-bit values) at a time. Assumes memory is contiguous - and pointers are 16-byte aligned. -*/ -__constant__ uint8_t kNibbleLUT[256]; - -__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, - size_t nbytes) { - constexpr size_t V = sizeof(uint4); // 16 bytes - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t nthreads = size_t(gridDim.x) * blockDim.x; - const size_t nvec = nbytes / V; - - // 1-D grid-stride loop over 16-byte chunks - for (size_t vec = tid; vec < nvec; vec += nthreads) { - uint4 v = reinterpret_cast(in)[vec]; - uint8_t* b = reinterpret_cast(&v); -#pragma unroll - for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; - reinterpret_cast(out)[vec] = v; - } -} - -static bool upload_lut() { - std::array lut{}; - auto map_nib = [](uint8_t v) -> uint8_t { - // 1..7 -> (8 - v); keep 0 and 8..15 - return (v == 0 || (v & 0x8)) ? 
v : uint8_t(8 - v); - }; - for (int b = 0; b < 256; ++b) { - uint8_t lo = b & 0xF; - uint8_t hi = (b >> 4) & 0xF; - lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); - } - cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), - /*offset=*/0, cudaMemcpyHostToDevice); - - return (e == cudaSuccess); -} - -static bool unified_encode_int4b(cutlass::int4b_t const* in, - cutlass::int4b_t* out, size_t num_int4_elems) { - // Build/upload LUT - if (!upload_lut()) return false; - - static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, - "int4 storage must be 1 byte"); - const size_t nbytes = num_int4_elems >> 1; - - auto* in_bytes = reinterpret_cast(in); - auto* out_bytes = reinterpret_cast(out); - - // kernel launch params - constexpr int block = 256; - const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors - int grid = int((nvec + block - 1) / block); - if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel - - unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); - cudaError_t err = cudaGetLastError(); - return (err == cudaSuccess); -} - torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dim() == 2); @@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { LayoutB_Reordered layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); - bool ok = - vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr, + n * k); TORCH_CHECK(ok, "unified_encode_int4b failed"); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu new file mode 100644 index 000000000000..f238d0a5b2d7 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu @@ -0,0 +1,90 @@ +#include "w4a8_utils.cuh" + +#include +#include +#include + +namespace vllm::cutlass_w4a8_utils { + +/* + GPU-accelerated implementation of cutlass::unified_encode_int4b. + Constructs a lookup table in constant memory to map 8 bits + (two 4-bit values) at a time. Assumes memory is contiguous + and pointers are 16-byte aligned. +*/ +__constant__ uint8_t kNibbleLUT[256]; + +__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, + size_t nbytes) { + constexpr size_t V = sizeof(uint4); // 16 bytes + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t nthreads = size_t(gridDim.x) * blockDim.x; + const size_t nvec = nbytes / V; + + // 1-D grid-stride loop over 16-byte chunks + for (size_t vec = tid; vec < nvec; vec += nthreads) { + uint4 v = reinterpret_cast(in)[vec]; + uint8_t* b = reinterpret_cast(&v); +#pragma unroll + for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; + reinterpret_cast(out)[vec] = v; + } +} + +static bool upload_lut() { + std::array lut{}; + auto map_nib = [](uint8_t v) -> uint8_t { + // 1..7 -> (8 - v); keep 0 and 8..15 + return (v == 0 || (v & 0x8)) ? 
v : uint8_t(8 - v); + }; + for (int b = 0; b < 256; ++b) { + uint8_t lo = b & 0xF; + uint8_t hi = (b >> 4) & 0xF; + lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); + } + cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), + /*offset=*/0, cudaMemcpyHostToDevice); + + return (e == cudaSuccess); +} + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems) { + // Build/upload LUT + if (!upload_lut()) return false; + + static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, + "int4 storage must be 1 byte"); + const size_t nbytes = num_int4_elems >> 1; + + auto* in_bytes = reinterpret_cast(in); + auto* out_bytes = reinterpret_cast(out); + + // kernel launch params + constexpr int block = 256; + const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors + int grid = int((nvec + block - 1) / block); + if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel + + unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); + + // launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device launch error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + // runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device runtime error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + return true; +} + +} // namespace vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh new file mode 100644 index 000000000000..25090091a368 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "cutlass/numeric_types.h" + +namespace vllm::cutlass_w4a8_utils { + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems); + +} // namespace vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 5b007e5ea328..674440278383 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -22,6 +22,7 @@ #include #include #include +#include "cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" @@ -173,7 +174,7 @@ void run_get_group_gemm_starts( } template -void run_fp4_blockwise_scaled_group_mm( +void run_fp4_blockwise_scaled_group_mm_sm100( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, @@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm( auto can_implement_status = gemm_op.can_implement(args); TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); + "Failed to implement GEMM: status=", (int)can_implement_status); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +void 
run_fp4_blockwise_scaled_group_mm_sm120( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + // NOTE: For SM120 it seems templating the output type is not supported and + // we need to hardcode the output type to bfloat16 + using ElementC = cutlass::bfloat16_t; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + + using FusionOperation = cutlass::epilogue::fusion::LinearCombination< + ElementD, ElementAccumulator, ElementC, ElementAccumulator>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutD*, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, + LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, 
options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + fusion_args.beta = 0.0f; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: status=", (int)can_implement_status); // Run the GEMM auto status = gemm_op.initialize(args, workspace.data_ptr()); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); status = gemm_op.run(args, 
workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + if (version_num >= 120 && version_num < 130) { + run_fp4_blockwise_scaled_group_mm_sm120( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 + if (version_num >= 100 && version_num < 120) { + run_fp4_blockwise_scaled_group_mm_sm100( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ", + version_num, ". Required capability: 100 or 120"); +} + +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; #endif @@ -374,7 +583,8 @@ void cutlass_fp4_group_mm( const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) // Input validation CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); @@ -408,6 +618,14 @@ void cutlass_fp4_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); } else { + #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + int32_t version_num = get_sm_version_num(); + if (version_num >= 120 && version_num < 130) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ", + output.scalar_type()); + } + #endif run_fp4_blockwise_scaled_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); @@ -416,8 +634,8 @@ void cutlass_fp4_group_mm( TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_fp4_group_mm kernel, vLLM must " - "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA " - "12.8 or above."); + "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 " + "and CUDA 12.8 or above."); #endif } diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 6d385e0dd94e..82c53c2375a3 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float; constexpr auto INT = at::ScalarType::Int; constexpr auto UINT8 = at::ScalarType::Byte; -void scaled_fp4_experts_quant_sm100a( +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, diff 
--git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index c2b39e543880..fb6d22f035b9 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input_sf); #endif -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 -void scaled_fp4_experts_quant_sm100a( +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, @@ -54,8 +55,9 @@ void scaled_fp4_experts_quant( torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return scaled_fp4_experts_quant_sm100a( +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return scaled_fp4_experts_quant_sm1xxa( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); #endif diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index 9cba2828aac2..d9c4d24d8e1f 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -15,6 +15,8 @@ */ #include +#include +#include "cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, @@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& alpha); #endif -void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); -#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 - return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); +void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& A_sf, + const torch::Tensor& B_sf, + const torch::Tensor& alpha) { + // Make sure we’re on A’s device. + const c10::cuda::OptionalCUDAGuard device_guard(device_of(A)); + const int32_t sm = get_sm_version_num(); + +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (sm >= 100 && sm < 120) { + cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); + return; + } +#endif + +#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (sm >= 120 && sm < 130) { + cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); + return; + } #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No compiled nvfp4 mm kernel, vLLM should " - "be compiled using CUDA 12.8 and target " - "compute capability 100 or above."); + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm, + ". 
Recompile with CUDA >= 12.8 and CC >= 100."); } bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { int runtimeVersion; cudaRuntimeGetVersion(&runtimeVersion); return cuda_device_capability >= 100 && runtimeVersion >= 12080; -} \ No newline at end of file +} diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 92d6c2f402a2..2080ef3cd39b 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } @@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert s_token_scale for exact match with FBGemm vllm::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } + +// RMS norm + quant kernel +template +__global__ void rms_norm_per_block_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens, hidden_size / group_size] + // or + // [hidden_size / group_size, num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms; + // Compute RMS + // Always able to vectorize due to constraints on hidden_size + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute Scale + // Always able to vectorize due to constraints on hidden_size and group_size + vllm::vectorized::compute_dynamic_per_token_scales< + scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( + nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); + + // RMS Norm + Quant + // Always able to vectorize due to constraints on hidden_size + // For int8, don't invert token_scale here: do it inside the norm_and_quant + // kernel. We do it because particular elements of token_scale can be shared + // between multiple threads, so this way, we avoid extra synchronization + // overhead. 
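// Concretely, with hidden_size = 1024 and group_size = 128 there are 8 scale
// groups per token: `scales` is laid out as [num_tokens, 8], or as
// [8, num_tokens] when the transposed layout is selected, matching the
// group-wise indexing performed inside norm_and_quant.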
+ vllm::vectorized::norm_and_quant< + scalar_t, scalar_out_t, std::is_same_v, + has_residual, is_scale_transposed, group_size>( + out, input, weight, rms, scales, hidden_size, residual); +} + } // namespace vllm // Residual add + RMS norm + dynamic per token @@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (residual.has_value()) { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { vllm::rms_norm_dynamic_per_token_quant_kernel + has_residual> <<>>( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr()); + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() : nullptr); }); - - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { - vllm::rms_norm_dynamic_per_token_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } + }); } void rms_norm_dynamic_per_token_quant( @@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant( out, input, weight, scales, var_epsilon, scale_ub, residual); }); } + +// Residual add + RMS norm + dynamic per token +void rms_norm_per_block_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or + // [hidden_size / group_size, num_tokens] + int32_t group_size, + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual, bool is_scale_transposed) { + int32_t hidden_size = input.size(-1); + auto num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + const int max_block_size = (num_tokens <= 256) ? 512 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] { + using scalar_in_t = scalar_t; + VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() + : nullptr); + }); + }); + }); + }); + }); +} + +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const var_epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? 
c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(weight.dtype() == input.dtype()); + TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } + + TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); + + rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, + var_epsilon, scale_ub, residual, + is_scale_transposed); +} \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 2d2fd771205c..cb7adc312573 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -9,6 +9,7 @@ #include "quant_conversions.cuh" #include "../../cub_helpers.h" +#include "../../cuda_compat.h" namespace vllm { @@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } -template +__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, + int64_t thread_in_warp, + int64_t reduced_elems) { + static_assert(WARP_SIZE == 32 || WARP_SIZE == 64); + if constexpr (WARP_SIZE == 64) { + if (thread_in_warp + 64 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 64]); + } + if (thread_in_warp + 32 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 32]); + if (thread_in_warp + 16 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 16]); + if (thread_in_warp + 8 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 8]); + if (thread_in_warp + 4 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 4]); + if (thread_in_warp + 2 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 2]); + if (thread_in_warp + 1 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 1]); + return val[tid]; +} + +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, - scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - constexpr scalar_out_t qmax{quant_type_max_v}; - + int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { float block_absmax_val_maybe = 0.0f; - for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); - if constexpr (has_residual) { - x += static_cast(residual[token_offset + i]); + constexpr scalar_out_t qmax{quant_type_max_v}; + __syncthreads(); + if (group_size > 0) { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = threadIdx.x / threads_per_group * group_size; + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = + min(group_offset + group_size, static_cast(hidden_size)); + for (auto i = thread_offset; i 
< thread_end; i += threads_per_group) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); + + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } - x = static_cast(static_cast(x * rms) * weight[i]); - block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store - } - __syncthreads(); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = 
scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); - *token_scale = s_token_scale; + *token_scale = s_token_scale; + } } template + bool has_residual = false, bool is_scale_transposed = false> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + scalar_t* __restrict__ residual = nullptr, + int32_t const group_size = 0) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, // Norm x = static_cast(static_cast(x * rms) * weight[i]); // Quant + // If groupwise is_scale_inverted is true, so we invert the scale here. + int64_t scale_idx = 0; + if (group_size > 0) { + if constexpr (is_scale_transposed) { + scale_idx = (i / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; + } + } + auto scale_val = + (group_size > 0 + ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]) + : *scale); output[token_offset + i] = - ScaledQuant::quant_fn(x, scale); + ScaledQuant::quant_fn(x, scale_val); } } @@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // Vectorized version of vllm::compute_dynamic_per_token_scales // hidden_size must be a multiple of 4 -template +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - - // Vectorized input/weight/residual to better utilize memory bandwidth. - vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); - vec4_t const* vec_weight = - reinterpret_cast const*>(weight); - vec4_t const* vec_residual = nullptr; - if constexpr (has_residual) { - vec_residual = - reinterpret_cast const*>(&residual[token_offset]); - } - constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; - int32_t const num_vec_elems = hidden_size >> 2; float block_absmax_val_maybe = 0.0f; + // Vectorized input/weight/residual to better utilize memory bandwidth. 
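// In the grouped (group_size > 0) path below, offsets are expressed in vec4
// units (hence the group_size >> 2 and hidden_size >> 2 shifts): each thread
// reduces a strided slice of its group into s_max_vals, the per-group absmax
// is finished with the shared-memory warp reduction
// (warpReduceMaxSpecialized), and thread 0 of each group writes the resulting
// scale to global memory.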
+ vec4_t const* vec_input = nullptr; + vec4_t const* vec_weight = nullptr; + vec4_t const* vec_residual = nullptr; + + if constexpr (group_size > 0) { + __shared__ float s_max_vals[1024]; + + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t const num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = + threadIdx.x / threads_per_group * (group_size >> 2); + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = min(group_offset + (group_size >> 2), + static_cast(hidden_size >> 2)); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + int32_t const num_vec_elems = thread_end; + #pragma unroll 4 - for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { - vec4_t in = vec_input[i]; - vec4_t const w = vec_weight[i]; + for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; - vec4_t x; + vec4_t x; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] = static_cast(in.val[j]); + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); + + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); if constexpr 
(has_residual) { - vec4_t r = vec_residual[i]; + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = (hidden_size >> 2); + +#pragma unroll 4 + for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] += static_cast(r.val[j]); + x.val[j] = static_cast(in.val[j]); } - } + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - block_absmax_val_maybe = - fmaxf(block_absmax_val_maybe, - fabs(static_cast(x.val[j] * rms) * w.val[j])); - } - } + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } + } - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // shared memory store - all_token_scales[blockIdx.x] = scale; // global output store - } - __syncthreads(); + __syncthreads(); - *token_scale = s_token_scale; + *token_scale = s_token_scale; + } } // hidden_size must be a multiple of 4 template + bool has_residual = false, bool is_scale_transposed = false, + int32_t group_size = 0> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } q8x4_t out; + + float scale_val; + + if constexpr (group_size > 0) { + int64_t const num_groups = hidden_size / group_size; + int64_t scale_idx = 0; + if constexpr (is_scale_transposed) { + scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; + } + scale_val = + is_scale_inverted ? 
1.0f / scale[scale_idx] : scale[scale_idx]; + } else { + scale_val = *scale; + } #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { out.val[j] = ScaledQuant::quant_fn( - static_cast(x.val[j] * rms) * w.val[j], scale); + static_cast(x.val[j] * rms) * w.val[j], scale_val); } vec_output[i] = out; } diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index 03bd5964a7fc..e306ff02605b 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { #pragma unroll for (int k_idx = 0; k_idx < 2; ++k_idx) { - FType low16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2]); - FType high16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); + FType low16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2]); + FType high16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2 + 1]); uint32_t tmp = (reinterpret_cast(low16) & 0xffff) | (reinterpret_cast(high16) << 16); int sts_offset = diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh index 831413016538..14a61ad8fd88 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -8,7 +8,7 @@ #include #include #include "../gptq_marlin/marlin_dtypes.cuh" -using marlin::ScalarType; +using marlin::MarlinScalarType2; namespace allspark { @@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; for (int i = 0; i < n_mat; ++i) { - sum += ScalarType::num2float(C_split[idx + i * matrix_size]); + sum += MarlinScalarType2::num2float(C_split[idx + i * matrix_size]); } - C[idx] = ScalarType::float2num(sum); + C[idx] = MarlinScalarType2::float2num(sum); } template diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore index 77088552b85b..ba805f9250ec 100644 --- a/csrc/quantization/gptq_marlin/.gitignore +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index e607107b3e77..307bae6738ec 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -4,14 +4,16 @@ namespace marlin { -template +template __global__ void awq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 
2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int tile_n_ints = tile_n_size / pack_factor; + constexpr int tile_n_ints = target_tile_n_size / pack_factor; constexpr int stage_n_threads = tile_n_ints / 4; - constexpr int stage_k_threads = tile_k_size; + constexpr int stage_k_threads = target_tile_k_size; constexpr int stage_size = stage_k_threads * stage_n_threads; auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { @@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int first_n_packed = first_n / pack_factor; int4* sh_ptr = sh + stage_size * pipe; @@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast( @@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; int cur_n_packed = cur_n / pack_factor; int cur_n_pos = cur_n % pack_factor; @@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel( uint32_t vals[8]; #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + - sh_stride * cur_elem]; - - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + if constexpr (is_a_8bit) { + int cur_elem = tc_row + i; + + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * cur_elem]; + int packed_src_1 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * (cur_elem + 16)]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } else { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + 
for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ +#define CALL_IF(NUM_BITS, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ } torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits) { + int64_t size_n, int64_t num_bits, + bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, if (false) { } - CALL_IF(4) - CALL_IF(8) + CALL_IF(4, false) + CALL_IF(8, false) + CALL_IF(4, true) + CALL_IF(8, true) else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index e8b0c302b202..26b8d40368aa 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,6 +470,50 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4; + constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70707070; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + + // Note1: reverse indexing is intentional because weights are permuted + // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of + // `frag_b[1]` to fit the layout of tensorcore + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, int32_t* frag_b) { + constexpr int repeated_zp = 0x08080808; + constexpr int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void 
dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + int s = q & 0x08080808; + int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + q >>= 4; + s = q & 0x08080808; + int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); +} + template __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); @@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales( // Note: reverse indexing is intentional because weights are permuted frag_b[1] = *reinterpret_cast(&Out1); frag_b[0] = *reinterpret_cast(&Out2); +}; + +// subtract zero point in quanted format and then dequant +template +__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp); + +template <> +__device__ inline void sub_zp_and_dequant( + int q, int32_t* frag_b, int zp) { + // INT4 with zp -> INT8 + // see https://github.com/vllm-project/vllm/pull/24722 + int repeated_zp = 0x01010101 * zp; + int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(), + true>(int q, __nv_fp8x4_e4m3* frag_b, + int zp) { + // INT4 with zp -> FP8 + // see https://github.com/vllm-project/vllm/pull/24722 + uint32_t u_q = *reinterpret_cast(&q); + uint32_t u_zp = *reinterpret_cast(&zp); + uint32_t u_zp1 = u_zp + 1; + uint32_t repeated_zp = 0x01010101 * u_zp; + + uint32_t q0, s; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + u_q >>= 4; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 42d3b456096e..27ef7271ba41 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -4,141 +4,292 @@ import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py +ARCHS = [] +SUPPORT_FP8 = False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. 
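# For reference, the slice-and-strip above maps compute-capability strings to
# integer arch codes, e.g. "8.9" -> 89, "9.0a" -> 90, "12.0" -> 120, so the
# check below enables the FP8-activation kernels only when building for SM89
# or SM120.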
+ if arch in [89, 120]: + SUPPORT_FP8 = True + +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py // clang-format off +""".lstrip() +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. -SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # HQQ + { + "a_type": ["kFloat16"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [4], + "is_zp_float": True, + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in 
glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - all_template_str_list = [] + result_dict = {} - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue - - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + is_zp_float = quant_config.get("is_zp_float", False) + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + if "16" in a_type and "16" in c_type and a_type != c_type: continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] - - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: - # HQQ (is_zp_float = true) only supports - # 4bit quantization and fp16 - is_zp_float_list.append(True) - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" - - for is_zp_float in is_zp_float_list: - template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=is_zp_float, - ) + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 
256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue - all_template_str_list.append(template_str) + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "true" if is_zp_float else "false", + } + + result_dict[(a_type, b_type, c_type)].append(config) + + kernel_selector_str = FILE_HEAD_COMMENT + + for (a_type, b_type, c_type), config_list in result_dict.items(): + all_template_str_list = [] + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) + + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) + + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " + + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) + + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index cc30abcf0080..28ff06559a98 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm( std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - 
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, 
K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with 
group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - - if (std::is_same::value) { - if (false) { - } - FZP_GET_IF(vllm::kU4) - } - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } + #include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem, - int sms) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, + int num_bits, int group_size, bool has_act_order, bool is_k_full, + bool has_zp, bool is_zp_float, int max_shared_mem, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + is_zp_float, max_shared_mem - 512)) { continue; } @@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, group_blocks = group_size == -1 ? 
-1 : group_size / 16; } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; @@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, - void* workspace, vllm::ScalarType const& q_type, bool has_bias, + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, + int lda, void* workspace, vllm::ScalarType const& a_type, + vllm::ScalarType const& b_type, vllm::ScalarType const& c_type, + vllm::ScalarType const& s_type, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); - } - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; - int* locks = (int*)workspace; if (has_act_order) { @@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + int max_par = 16; if (prob_n <= 4096) max_par = 16 * 8; int max_shared_mem_new = max_shared_mem; @@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_n = thread_n_init; int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); - int m_block_size_8 = prob_m_split <= 8; + int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16; // Set thread config exec_config_t exec_cfg; @@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem, sms); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, + thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem, sms); thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_n != -1) { + if (prob_n / thread_tfg.thread_n * + div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <= + sms) { + if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem_new)) { + thread_tfg = {128, 64, 128}; + exec_cfg = {1, thread_tfg}; + } + } + } + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { max_thread_m_blocks--; continue; @@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", max_shared_mem_new = ", max_shared_mem_new); - auto kernel = get_marlin_kernel( - q_type, 
thread_m_blocks, thread_n_blocks, thread_k_blocks, - m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, - is_zp_float); + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on - A_ptr += prob_m_split * (lda / 8); + bool is_a_8bit = a_type.size_bits() == 8; + A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8)); + a_s_ptr += prob_m_split; C_ptr += prob_m_split * (prob_n / 8); rest_m -= prob_m_split; } @@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); + } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn 
(for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; @@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce) { int max_m_block_size = (size_m + 16 - 1) / 16 * 16; max_m_block_size = min(max_m_block_size, 64); @@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. 
Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + marlin::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), + workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index ad80d51ece94..796e6c5359da 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -4,15 +4,18 @@ namespace marlin { -template +template __global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / 
num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int perm_size = tile_k_size / 4; + constexpr int perm_size = target_tile_k_size / 4; int4* sh_perm_ptr = sh; int4* sh_pipe_ptr = sh_perm_ptr; @@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel( sh_pipe_ptr += perm_size; } - constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int tile_ints = target_tile_k_size / pack_factor; - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_n_threads = target_tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; + int first_k_int4 = (k_tile_id * target_tile_k_size) / 4; int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); @@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; @@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], @@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; - constexpr int sh_stride = 64; + constexpr int sh_stride = target_tile_n_size; constexpr uint32_t mask = (1 << num_bits) - 1; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; @@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel( uint32_t vals[8]; if constexpr (has_perm) { + static_assert(!is_a_8bit); for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; @@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel( #pragma unroll for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + if constexpr (is_a_8bit) { + b1_vals[i] = + sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8]; + } else { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } } #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; + int cur_elem = tc_row + (is_a_8bit ? 
i : tc_offsets[i]); int cur_int = cur_elem / pack_factor; int cur_pos = cur_elem % pack_factor; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + if constexpr (is_a_8bit) + vals[4 + i] = + (b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask; + else + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ +#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \ + is_a_8bit == IS_A_8BIT) { \ cudaFuncSetAttribute( \ marlin::gptq_marlin_repack_kernel, \ + HAS_PERM, IS_A_8BIT>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ marlin::gptq_marlin_repack_kernel \ + HAS_PERM, IS_A_8BIT> \ <<>>( \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, - int64_t num_bits) { + int64_t num_bits, bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, if (false) { } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) + CALL_IF(4, false, false) + CALL_IF(4, true, false) + CALL_IF(8, false, false) + CALL_IF(8, true, false) + + CALL_IF(4, false, true) + CALL_IF(8, false, true) + else { TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); + ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index bb454f6aff22..b3b79c8aec45 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -11,17 +11,19 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ b_bias_ptr, 
\ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ int max_shared_mem namespace MARLIN_NAMESPACE_NAME { -template (__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index cc1605481434..a4807a6887f8 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -2,8 +2,10 @@ #ifndef _data_types_cuh #define _data_types_cuh #include "marlin.cuh" +#include "core/scalar_type.hpp" #include #include +#include #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin @@ -11,14 +13,16 @@ namespace MARLIN_NAMESPACE_NAME { -template -class ScalarType {}; +template +class MarlinScalarType {}; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = half; using scalar_t2 = half2; + using scalar_t4 = half2; + using scalar_32bit_t = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: @@ -27,6 +31,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; static __device__ float inline num2float(const half x) { @@ -44,18 +49,25 @@ class ScalarType { static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } }; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; + using scalar_t4 = nv_bfloat162; + using scalar_32bit_t = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 @@ -75,9 +87,63 @@ class ScalarType { static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } #endif 
}; +template <> +class MarlinScalarType { + public: + using scalar_t = __nv_fp8_e4m3; + using scalar_t2 = __nv_fp8x2_e4m3; + using scalar_t4 = __nv_fp8x4_e4m3; + using scalar_32bit_t = __nv_fp8x4_e4m3; + + using FragA = Vec<__nv_fp8x4_e4m3, 4>; + using FragB = Vec<__nv_fp8x4_e4m3, 2>; + using FragC = Vec; + using FragZP = Vec<__nv_fp8x2_e4m3, 4>; + + static __host__ __device__ + float2 inline num22float2(const __nv_fp8x2_e4m3 x) { + return (float2)x; + } +}; + +template <> +class MarlinScalarType { + public: + using scalar_t = int8_t; + using scalar_t2 = int16_t; + using scalar_t4 = int32_t; + using scalar_32bit_t = int32_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragZP = Vec; +}; + +template +class MarlinScalarType2 {}; + +template <> +class MarlinScalarType2 : public MarlinScalarType {}; + +template <> +class MarlinScalarType2 + : public MarlinScalarType {}; + +template <> +class MarlinScalarType2<__nv_fp8_e4m3> + : public MarlinScalarType {}; + +template <> +class MarlinScalarType2 : public MarlinScalarType {}; + } // namespace MARLIN_NAMESPACE_NAME #endif diff --git a/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu new file mode 100644 index 000000000000..7d4c97fb57ed --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu @@ -0,0 +1,106 @@ + + +#include "marlin.cuh" + +#include "core/registration.h" + +// for only non-zp format (like gptq) +__global__ void marlin_int4_fp8_preprocess_kernel_without_zp( + // qweight: (size_k * size_n // 8,) + const int32_t* __restrict__ qweight, + // output: same shape with qweight + int32_t* __restrict__ output) { + int32_t val = qweight[blockIdx.x * 32 + threadIdx.x]; + int32_t new_val = 0; + +#pragma unroll + for (int32_t i = 0; i < 8; i++) { + int32_t single_val = val & 0xF; + single_val = single_val >= 8 ? single_val - 8 : 15 - single_val; + new_val |= single_val << (i * 4); + val >>= 4; + } + + output[blockIdx.x * 32 + threadIdx.x] = new_val; +} + +// for awq format only (with zp and with awq weight layout) +__global__ void marlin_int4_fp8_preprocess_kernel_awq( + // AWQ qweight: (size_k, size_n // 8) + const int32_t* __restrict__ qweight, + // output: same shape with qweight + int32_t* __restrict__ output, + // AWQ zeros: (size_k // group_size, size_n // 8) + const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k, + int32_t group_size) { + int32_t val = + qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y]; + int32_t zero = + qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 + + blockIdx.y]; + int32_t new_val = 0; + +#pragma unroll + for (int32_t i = 0; i < 8; i++) { + int32_t single_val = val & 0xF; + int32_t single_zero = zero & 0xF; + + single_val = + single_val >= single_zero ? single_val - single_zero : 15 - single_val; + new_val |= single_val << (i * 4); + val >>= 4; + zero >>= 4; + } + + output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val; +} + +torch::Tensor marlin_int4_fp8_preprocess( + torch::Tensor& qweight, std::optional qzeros_or_none, + bool inplace) { + TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU"); + TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int, + "qweight.dtype != torch.int32"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight)); + + torch::Tensor output = inplace ? 
qweight : torch::empty_like(qweight); + + if (!qzeros_or_none.has_value()) { + TORCH_CHECK(qweight.numel() * 8 % 256 == 0, + "qweight.numel() * 8 % 256 != 0"); + + int blocks = qweight.numel() * 8 / 256; + marlin_int4_fp8_preprocess_kernel_without_zp<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr()); + } else { + int32_t size_k = qweight.size(0); + int32_t size_n = qweight.size(1) * 8; + torch::Tensor qzeros = qzeros_or_none.value(); + + TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0"); + TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU"); + TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int, + "qweight.dtype != torch.int32"); + TORCH_CHECK(device_of(qweight) == device_of(qzeros), + "qzeros is not on the same device with qweight"); + + int32_t group_size = qweight.size(0) / qzeros.size(0); + TORCH_CHECK(qweight.size(1) == qzeros.size(1), + "qweight.size(1) != qzeros.size(1)"); + TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0, + "qweight.size(0) % qzeros.size(0) != 0"); + TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0"); + + dim3 blocks(size_k / 32, size_n / 8); + marlin_int4_fp8_preprocess_kernel_awq<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(), + (const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size); + } + + return output; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess); +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index bfb0a3668f52..22bb71e482ce 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if 
constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : 
"=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); 
} // Wait until barrier reaches `count`, then lock for current threadblock. @@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template __global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ A0, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C0, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ b_bias_ptr, - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + // float scales of input matrix, only used when is_a_8bit == true. + // shape (m,) + const float* __restrict__ a_scales_ptr, + // fp16 quantization scales. shape (k/groupsize, n) + const int4* __restrict__ scales_ptr, + // fp16 global scale (for nvfp4// only) + const uint16_t* __restrict__ global_scale_ptr, + // 4bit packed zero-points of shape + // (k/groupsize, n/pack_factor) + const int4* __restrict__ zp_ptr, + // int32 group indices of shape k + const int* __restrict__ g_idx, int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -321,17 +409,35 @@ __global__ void Marlin( // ensures good utilization of all SMs for many kinds of shape and GPU // configurations, while requiring as few slow global cross-threadblock // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. 
+ if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + const int4* A = A0; + int4* C = C0; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; + + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -340,27 +446,35 @@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); + + c_scalar_t2 global_scale; - scalar_t2 global_scale; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - // NVFP4 format requires global scale - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[0]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); - constexpr int pack_factor = 32 / w_type.size_bits(); + extern __shared__ int4 sh[]; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_new = sh + (is_a_8bit ? 
(4 * thread_m_blocks) : 0); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); // For larger GEMMs we run multiple batchsize 64 versions in parallel for a @@ -373,7 +487,19 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -385,28 +511,21 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int locks_off = 0; - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { @@ -415,10 +534,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. - auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -436,7 +556,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -466,28 +586,68 @@ __global__ void Marlin( } if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * lda / 8; + A += 16 * thread_m_blocks * lda / (is_a_8bit ? 
16 : 8); C += 16 * thread_m_blocks * prob_n / 8; slice_col = 0; par_id++; } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } }; - init_slice(true); + + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + if (is_a_8bit) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = lda / 8; + int a_gl_stride = lda / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -496,24 +656,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 
2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -527,7 +688,7 @@ __global__ void Marlin( int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -550,17 +711,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; + + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -571,58 +737,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -638,12 +800,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -678,26 +844,19 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; constexpr int sh_b_size = stages * b_sh_stage; - int4* sh_b = sh; - int4* sh_red = sh; - + int4* sh_b = sh_new; + int4* sh_red = sh_new; constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); constexpr int sh_size_b_red_max = @@ -708,8 +867,8 @@ __global__ void Marlin( ? sh_size_b_red_max : (sh_size_b_red_min + sh_bias_size); - int4* sh_bias = sh + sh_size_b_red_min; - int4* sh_g_idx = sh + sh_b_red_bias_size; + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -723,7 +882,8 @@ __global__ void Marlin( // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -731,6 +891,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. 
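The frag_c_tmp registers zero-initialized just above exist because, on the 8-bit activation path, raw MMA results are first collected per scale group and only folded into the persistent accumulators once that group is finished (see matmul_a8 further down). A rough sketch of this two-level accumulation (not part of the patch), using plain arrays and invented names, assuming fp32 partials as in the fp8 branch; the int8 branch performs the same fold on int32 partials:

// Illustrative only: per-group partial sums (acc_tmp) are scaled once per
// group, folded into the persistent accumulators (acc), then cleared.
void fold_group_sketch(float* acc, float* acc_tmp, const float* group_scale,
                       int n) {
  for (int i = 0; i < n; i++) {
    acc[i] += acc_tmp[i] * group_scale[i];
    acc_tmp[i] = 0.0f;  // start fresh for the next scale group
  }
}

Folding once per group keeps the scale multiply out of the inner MMA loop.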
auto zero_accums = [&]() { #pragma unroll @@ -788,15 +966,17 @@ __global__ void Marlin( } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; @@ -816,44 +996,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -891,14 +1051,14 @@ __global__ void Marlin( int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -922,53 +1082,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 
2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - } else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { + } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } else { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -989,18 +1150,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1055,18 +1213,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / 
thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1075,21 +1231,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int warp_row = warp_id / tb_n_warps; - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1108,29 +1254,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1141,33 +1276,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1177,14 +1325,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1195,61 +1343,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr (b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 
4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? 
frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1263,7 +1518,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1278,7 +1534,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1287,24 +1544,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1320,10 +1579,10 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; if constexpr (m_block_size_8) { @@ -1331,9 +1590,9 @@ __global__ void Marlin( 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; c_gl_wr += (2 * thread_n_blocks) * slice_col; } else { - c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); } constexpr int c_sh_wr_delta = active_threads; auto c_sh_wr = threadIdx.x; @@ -1351,6 +1610,14 @@ __global__ void Marlin( &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], (threadIdx.x % 4) * 2 + i < prob_m); + } else if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + cp_async2_ca_pred( + &sh_red_int2[c_sh_wr + c_sh_wr_delta * i], + &c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } else { cp_async4_pred( &sh_red[c_sh_wr + c_sh_wr_delta * i], @@ -1370,36 +1637,51 @@ __global__ void Marlin( (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); if (mask) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta]); } - if constexpr (m_block_size_8) + if constexpr (m_block_size_8) { C[c_gl_wr + i * c_gl_stride + - (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; - else + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = + *reinterpret_cast(c_f16); + } else if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)] = + *reinterpret_cast(c_f16); + } else { C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)] = c; + c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast(c_f16); + } } } } @@ -1414,10 +1696,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1471,7 +1753,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1481,47 +1763,47 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) { + for (int j = 0; j < (is_a_8bit ? 
2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1557,9 +1839,9 @@ __global__ void Marlin( i++) { if (c_gl_wr < c_gl_wr_end) { if (use_atomic_add && slice_count > 1) { - scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); #pragma unroll for (int a = 0; a < 4; a++) { atomicAdd(&C_half2[a], sh_red_half2[a]); @@ -1635,7 +1917,13 @@ __global__ void Marlin( wait_for_stage(); init_same_group(pipe % stages); } - matmul(k); + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); + } } slice_iters--; if (slice_iters == 0) { @@ -1668,13 +1956,47 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1692,20 +2014,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1715,26 +2044,48 @@ 
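At write-out, the write() lambda above packs pairs of fp32 accumulators into the 16-bit output type and applies the remaining per-column scale and, on the last slice, the bias. A stand-alone sketch of that conversion (not part of the patch), assuming fp16 output and invented names:

#include <cuda_fp16.h>

// Illustrative only: fp32 -> fp16 pack, then per-column scale and bias.
__device__ __half2 pack_scale_bias_sketch(float c0, float c1,
                                          __half2 col_scale, __half2 bias) {
  __half2 res = __floats2half2_rn(c0, c1);  // convert both accumulators
  res = __hmul2(res, col_scale);            // per-column scale (if any)
  return __hadd2(res, bias);                // bias add on the last slice
}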
__global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1758,7 +2109,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1768,21 +2120,22 @@ __global__ void Marlin( // only the last block in a slice actually writes the result write_result(last); slice_row = 0; - slice_col_par++; - slice_col++; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } is_first_matmul_in_slice = true; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd += a_gl_rd_delta_o * slice_row; + b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1791,12 +2144,28 @@ __global__ void Marlin( slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + 
s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } - start_pipes(); } } diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 2f52a6b7a024..9f02f4f17974 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -617,7 +617,7 @@ struct MacheteCollectiveMma { // Same as upstream, should be kept the same when possible, not formatted for // easier comparison - // with `SwapAB ? N : M -> M` since we dont support SwapAB + // with `SwapAB ? N : M -> M` since we don't support SwapAB // clang-format off template static bool diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 49cafcc32adc..99fec8fd6feb 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); // Swap-AB should be disabled for FP4 path - bool may_swap_ab = (!blockscale_offsets.has_value()) && - (topk_ids.numel() <= SWAP_AB_THRESHOLD); + bool may_swap_ab = + force_swap_ab.value_or((!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD)); launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, atomic_buffer, num_experts, n, k, stream, diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 1001af05ff00..5de21cfbbaaf 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif -#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ - defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ - defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 +#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \ + (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \ + (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120) void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& 
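In moe_data.cu above, may_swap_ab now resolves through force_swap_ab.value_or(...): passing std::nullopt keeps the existing heuristic (swap only when there are no block-scale offsets and topk_ids.numel() is at or below SWAP_AB_THRESHOLD), while an explicit true/false pins the decision. A minimal sketch of that resolution, with an illustrative threshold constant since the real value is defined elsewhere in the MoE sources:

#include <cstdint>
#include <cstdio>
#include <optional>

constexpr int64_t SWAP_AB_THRESHOLD = 64;  // illustrative placeholder, not the real constant

bool resolve_may_swap_ab(bool has_blockscale_offsets, int64_t topk_numel,
                         std::optional<bool> force_swap_ab = std::nullopt) {
  // value_or() keeps the heuristic unless the caller forces a decision.
  return force_swap_ab.value_or(!has_blockscale_offsets &&
                                topk_numel <= SWAP_AB_THRESHOLD);
}

int main() {
  printf("%d %d %d\n",
         resolve_may_swap_ab(false, 32),        // heuristic decides: swap
         resolve_may_swap_ab(true, 32),         // blockscale (FP4) path: no swap
         resolve_may_swap_ab(true, 32, true));  // forced on regardless of the heuristic
  return 0;
}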
problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -284,8 +285,9 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k, @@ -296,26 +298,28 @@ void get_cutlass_moe_mm_data( false, "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, - blockscale_offsets); + blockscale_offsets, force_swap_ab); return; #endif TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " "kernel for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, @@ -328,8 +332,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, num_local_experts, padded_m, n, k); @@ -339,7 +344,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, false, "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". 
Required capability: 90, 100, or 120"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu index e3ab0676b254..49d1b2086b8d 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) { return val; } +template +__device__ __forceinline__ float ComputeGroupScale( + const T* __restrict__ group_input, T* __restrict__ smem_group, + const int group_size, const int lane_id, const int threads_per_group, + const float eps, const float max_8bit) { + float local_absmax = eps; + + constexpr int vec_size = 16 / sizeof(T); + + // copy global -> shared & compute absmax + auto scalar_op_cache = [&] __device__(T & dst, const T& src) { + float abs_v = fabsf(static_cast(src)); + local_absmax = fmaxf(local_absmax, abs_v); + dst = src; + }; + + vllm::vectorize_with_alignment( + group_input, // in + smem_group, // out (shared) + group_size, // elements per group + lane_id, // thread id + threads_per_group, // stride in group + scalar_op_cache); // scalar handler + + local_absmax = GroupReduceMax(local_absmax); + + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + } + + return y_s; +} + +template +__device__ __forceinline__ void QuantizeGroup( + const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output, + const int group_size, const int lane_id, const int threads_per_group, + const float y_s, const float min_8bit, const float max_8bit) { + constexpr int vec_size = 16 / sizeof(T); + + // quantize shared -> global 8-bit + auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { + float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); + dst = DST_DTYPE(q); + }; + + vllm::vectorize_with_alignment( + smem_group, // in (shared) + group_output, // out (global quant tensor) + group_size, // elements + lane_id, // tid + threads_per_group, // stride + scalar_op_quant); // scalar handler +} + template __global__ void per_token_group_quant_8bit_kernel( @@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel( const int64_t global_group_id = block_group_id + local_group_id; const int64_t block_group_offset = global_group_id * group_size; - float local_absmax = eps; - using scale_element_t = float; static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); @@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel( T* smem = reinterpret_cast(smem_raw); T* smem_group = smem + local_group_id * group_size; - constexpr int vec_size = 16 / sizeof(T); - using vec_t = vllm::vec_n_t; - - // copy global -> shared & compute absmax - auto scalar_op_cache = [&] __device__(T & dst, const T& src) { - float abs_v = fabsf(static_cast(src)); - local_absmax = fmaxf(local_absmax, abs_v); - dst = src; - }; - - vllm::vectorize_with_alignment( - group_input, // in - smem_group, // out (shared) - group_size, // elements per group - lane_id, // thread id - threads_per_group, // stride in group - scalar_op_cache); // scalar handler - - local_absmax = GroupReduceMax(local_absmax); - - float y_s = local_absmax / max_8bit; - if constexpr (SCALE_UE8M0) { - y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); - } + const float y_s = ComputeGroupScale( + group_input, smem_group, group_size, lane_id, threads_per_group, eps, + 
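ComputeGroupScale above derives each group's scale as absmax / max_8bit and, when SCALE_UE8M0 is set, rounds it up to the next power of two so that only an 8-bit exponent needs to be stored. A sequential C++ sketch of the same arithmetic, assuming the FP8 e4m3 maximum of 448 (an INT8 output would use 127):

#include <cmath>
#include <cstdio>

// Sketch of the per-group scale computed by ComputeGroupScale.
float group_scale(const float* group, int group_size, float eps, float max_8bit,
                  bool scale_ue8m0) {
  float absmax = eps;
  for (int i = 0; i < group_size; ++i) absmax = std::fmax(absmax, std::fabs(group[i]));
  float y_s = absmax / max_8bit;
  if (scale_ue8m0) {
    // Round up to a power of two so the scale is exactly an 8-bit exponent.
    y_s = std::exp2(std::ceil(std::log2(std::fmax(std::fabs(y_s), 1e-10f))));
  }
  return y_s;
}

int main() {
  float group[4] = {0.1f, -3.0f, 2.5f, 0.7f};
  printf("%g\n", group_scale(group, 4, 1e-10f, 448.0f, /*scale_ue8m0=*/true));
  // absmax = 3.0, 3.0 / 448 ~= 0.0067 -> rounded up to 2^-7 = 0.0078125
  return 0;
}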
max_8bit); scale_element_t y_s_quant = y_s; @@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel( __syncthreads(); - // quantize shared -> global 8-bit - auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { - float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); - dst = DST_DTYPE(q); - }; + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} - vllm::vectorize_with_alignment( - smem_group, // in (shared) - group_output, // out (global quant tensor) - group_size, // elements - lane_id, // tid - threads_per_group, // stride - scalar_op_quant); // scalar handler +inline int GetGroupsPerBlock(int64_t num_groups) { + if (num_groups % 16 == 0) { + return 16; + } + if (num_groups % 8 == 0) { + return 8; + } + if (num_groups % 4 == 0) { + return 4; + } + if (num_groups % 2 == 0) { + return 2; + } + return 1; } void per_token_group_quant_8bit(const torch::Tensor& input, @@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input, constexpr int THREADS_PER_GROUP = 16; - int groups_per_block = 1; - - if (num_groups % 16 == 0) { - groups_per_block = 16; - } else if (num_groups % 8 == 0) { - groups_per_block = 8; - } else if (num_groups % 4 == 0) { - groups_per_block = 4; - } else if (num_groups % 2 == 0) { - groups_per_block = 2; - } + const int groups_per_block = GetGroupsPerBlock(num_groups); auto dst_type = output_q.scalar_type(); const int num_blocks = num_groups / groups_per_block; @@ -206,6 +234,148 @@ void per_token_group_quant_8bit(const torch::Tensor& input, #undef LAUNCH_KERNEL } +template +__global__ void per_token_group_quant_8bit_packed_kernel( + const T* __restrict__ input, void* __restrict__ output_q, + unsigned int* __restrict__ output_s_packed, const int group_size, + const int num_groups, const int groups_per_block, const int groups_per_row, + const int mn, const int tma_aligned_mn, const float eps, + const float min_8bit, const float max_8bit) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + if (global_group_id >= num_groups) { + return; + } + + const int64_t block_group_offset = global_group_id * group_size; + + const T* group_input = input + block_group_offset; + DST_DTYPE* group_output = + static_cast(output_q) + block_group_offset; + + // shared memory to cache each group's data to avoid double DRAM reads. + extern __shared__ __align__(16) char smem_raw[]; + T* smem = reinterpret_cast(smem_raw); + T* smem_group = smem + local_group_id * group_size; + const float y_s = + ComputeGroupScale(group_input, smem_group, group_size, lane_id, + threads_per_group, eps, max_8bit); + + // pack 4 scales into a uint32 + if (lane_id == 0) { + // map flat group id to 2D indices (mn_idx, sf_k_idx) + const int sf_k_idx = static_cast(global_group_id % groups_per_row); + const int mn_idx = static_cast(global_group_id / groups_per_row); + + if (mn_idx < mn) { + // each uint32 in output_s_packed stores 4 packed scales + const int sf_k_pack_idx = sf_k_idx / 4; + const int pos = sf_k_idx % 4; + + // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit + // exponent, and place it into the correct byte of the 32-bit word. 
+ const unsigned int bits = __float_as_uint(y_s); + const unsigned int exponent = (bits >> 23u) & 0xffu; + const unsigned int contrib = exponent << (pos * 8u); + + const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx; + // atomically OR 8-bit exponent into the packed scales buffer + atomicOr(output_s_packed + out_idx, contrib); + } + } + + __syncthreads(); + + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} + +void per_token_group_quant_8bit_packed(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s_packed, + int64_t group_size, double eps, + double min_8bit, double max_8bit) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(output_q.is_contiguous()); + + const int64_t k = input.size(-1); + TORCH_CHECK(k % group_size == 0, "Last dimension (", k, + ") must be divisible by group_size (", group_size, ")."); + + const int64_t mn = input.numel() / k; + const int64_t groups_per_row = k / group_size; + const int64_t num_groups = mn * groups_per_row; + + TORCH_CHECK(output_s_packed.dim() == 2, + "output_s_packed must be 2D, got dim=", output_s_packed.dim(), + "."); + + const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4; + const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4; + + TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int, + "output_s_packed must have dtype int32 for UE8M0-packed scales."); + // DeepGEMM expects SFA scales in MN-major form with shape + // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last + // dimension. + TORCH_CHECK(output_s_packed.size(0) == mn && + output_s_packed.size(1) == k_num_packed_sfk, + "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, + "], but got [", output_s_packed.size(0), ", ", + output_s_packed.size(1), "]."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + constexpr int THREADS_PER_GROUP = 16; + + const int groups_per_block = GetGroupsPerBlock(num_groups); + + auto dst_type = output_q.scalar_type(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + // zero-initialize packed scales, since we use atomicOr to accumulate + // exponents from different groups. 
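The packing step above keeps only the 8-bit IEEE exponent of the power-of-two UE8M0 scale ((bits >> 23) & 0xff), places it at byte position sf_k_idx % 4, and ORs it into the uint32 word at index sf_k_pack_idx * tma_aligned_mn + mn_idx; the buffer is zeroed beforehand so atomicOr can accumulate the four exponents independently. The sequential C++ sketch below mirrors that pack plus the matching unpack, with made-up sizes:

#include <cstdint>
#include <cstring>
#include <cmath>
#include <vector>
#include <cstdio>

// Pack the UE8M0 scale for (row mn_idx, scale-group sf_k_idx) into a buffer
// laid out with a TMA-aligned row count, matching the kernel's indexing.
void pack_scale(std::vector<uint32_t>& out, int tma_aligned_mn,
                int mn_idx, int sf_k_idx, float y_s) {
  uint32_t bits;
  std::memcpy(&bits, &y_s, sizeof(bits));
  uint32_t exponent = (bits >> 23) & 0xffu;         // y_s is a power of two, mantissa is zero
  int pack_idx = sf_k_idx / 4, pos = sf_k_idx % 4;  // four scales per uint32 word
  out[pack_idx * tma_aligned_mn + mn_idx] |= exponent << (pos * 8);
}

float unpack_scale(const std::vector<uint32_t>& in, int tma_aligned_mn,
                   int mn_idx, int sf_k_idx) {
  uint32_t word = in[(sf_k_idx / 4) * tma_aligned_mn + mn_idx];
  uint32_t exponent = (word >> ((sf_k_idx % 4) * 8)) & 0xffu;
  return std::ldexp(1.0f, static_cast<int>(exponent) - 127);  // rebuild 2^(e-127)
}

int main() {
  int mn = 3, groups_per_row = 6;         // illustrative sizes only
  int tma_aligned_mn = (mn + 3) / 4 * 4;  // pad the MN extent to a multiple of 4
  int packed_cols = (groups_per_row + 3) / 4;
  std::vector<uint32_t> packed(packed_cols * tma_aligned_mn, 0u);  // like output_s_packed.zero_()
  pack_scale(packed, tma_aligned_mn, /*mn_idx=*/2, /*sf_k_idx=*/5, 0.0078125f);  // 2^-7
  printf("%g\n", unpack_scale(packed, tma_aligned_mn, 2, 5));                    // 0.0078125
  return 0;
}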
+ output_s_packed.zero_(); + +#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + size_t smem_bytes = \ + static_cast(groups_per_block) * group_size * sizeof(T); \ + per_token_group_quant_8bit_packed_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + reinterpret_cast(output_s_packed.data_ptr()), \ + static_cast(group_size), static_cast(num_groups), \ + groups_per_block, static_cast(groups_per_row), \ + static_cast(mn), static_cast(tma_aligned_mn), \ + static_cast(eps), static_cast(min_8bit), \ + static_cast(max_8bit)); \ + } while (0) + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] { + if (dst_type == at::ScalarType::Float8_e4m3fn) { + LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3); + } else if (dst_type == at::ScalarType::Char) { + LAUNCH_PACKED_KERNEL(scalar_t, int8_t); + } else { + TORCH_CHECK( + false, + "per_token_group_quant_8bit_packed only supports FP8/INT8 " + "outputs."); + } + })); + +#undef LAUNCH_PACKED_KERNEL +} + void per_token_group_quant_fp8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double fp8_min, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 2ef579a1b753..8ebe55cef391 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, } #endif // defined(__HIP__GFX9__) TODO: Add NAVI support +// Find the min val of div2 that doesn't increase N/(div1*div2) int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; - int rnds0 = N / nPrRnd; - nPrRnd -= div1 * 3; - int rnds3 = N / nPrRnd; - nPrRnd -= div1; - int rnds4 = N / nPrRnd; - nPrRnd -= div1; - int rnds5 = N / nPrRnd; - nPrRnd -= div1; - int rnds6 = N / nPrRnd; - nPrRnd -= div1; - int rnds7 = N / nPrRnd; - nPrRnd -= div1; - int rnds8 = N / nPrRnd; - nPrRnd -= div1; - int rnds9 = N / nPrRnd; - nPrRnd -= div1; - int rtn = div2; - if (rnds0 == rnds3) rtn = div2 - 3; - if (rnds0 == rnds4) rtn = div2 - 4; - if (rnds0 == rnds5) rtn = div2 - 5; - if (rnds0 == rnds6) rtn = div2 - 6; - if (rnds0 == rnds7) rtn = div2 - 7; - if (rnds0 == rnds8) rtn = div2 - 8; - if (rnds0 == rnds9) rtn = div2 - 9; - return rtn; + int rnds[13]; + for (int i = 0; i < 13; i++) { + rnds[i] = (N + nPrRnd - 1) / nPrRnd; + nPrRnd -= div1; + } + for (int i = 12; i >= 0; i--) + if (rnds[0] == rnds[i]) return (div2 - i); } torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, @@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int max_lds_len = get_lds_size() / 2; -#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - } else if (K_in * N_in <= max_lds_len * 1.2) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSplitK_hf_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - } else { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSplitK_hf_big_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - 
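The rewritten mindiv in skinny_gemms.cu replaces the unrolled special cases with a loop over 13 candidates: it precomputes ceil(N / (div1 * (div2 - i))) for i = 0..12 and returns the smallest div2 - i that still needs the same number of rounds as the full div2, i.e. the least parallelism that does not add a round. An equivalent, more direct host-side formulation:

#include <cstdio>

// Smallest d <= div2 such that ceil(N / (div1 * d)) equals ceil(N / (div1 * div2)),
// i.e. shrinking the per-round parallelism does not add extra rounds.
int mindiv_ref(int N, int div1, int div2) {
  auto rounds = [&](int d) { return (N + div1 * d - 1) / (div1 * d); };
  int best = div2;
  for (int d = div2; d >= 1 && d >= div2 - 12; --d)  // same 13-candidate window as the kernel helper
    if (rounds(d) == rounds(div2)) best = d;
  return best;
}

int main() {
  // E.g. N = M_in = 1000 output rows, div1 = CuCount * YTILE = 160, div2 = 16 waves.
  printf("%d\n", mindiv_ref(1000, 160, 16));  // 7: seven waves already finish in one round
  return 0;
}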
} \ +#define WVSPLITK(_YTILE, _UNRL, _N) \ + { \ + dim3 block(64, 16); \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ + wvSplitK_hf_sml_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + else if (K_in * N_in <= max_lds_len * 1.2) \ + wvSplitK_hf_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + else \ + wvSplitK_hf_big_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + } + +#define WVSPLIT_TILE(_sYT, __N) \ + { \ + bool fit_lds = (K_in * N_in <= max_lds_len); \ + if (_sYT <= 1) \ + WVSPLITK(1, 4, __N) \ + else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \ + WVSPLITK(2, 2, __N) \ + else if (_sYT <= 4 * 3) \ + WVSPLITK(3, 2, __N) \ + else if (__N == 4) \ + WVSPLITK(4, 1, __N) \ + else \ + WVSPLITK(4, 2, __N) \ } AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { @@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ? reinterpret_cast(in_bias->data_ptr()) : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); + + // first shoot for biggest tile-size that keeps all simd busy, + // then cut the active waves to balance their distribution... + int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4); + switch (N_in) { case 1: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) + WVSPLIT_TILE(sYT, 1) break; case 2: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) + WVSPLIT_TILE(sYT, 2) break; case 3: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) + WVSPLIT_TILE(sYT, 3) break; case 4: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) + WVSPLIT_TILE(sYT, 4) break; default: throw std::runtime_error( diff --git a/csrc/sampler.cu b/csrc/sampler.cu index 410b8988f493..fc2154beff9e 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -44,41 +44,300 @@ __global__ void apply_repetition_penalties_kernel( } } -static inline __device__ uint16_t extractBinIdx(float x) { - union { - __half h; - uint16_t u16; - } tmp; - tmp.h = __float2half_rn(x); - tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000); - return 511 - (tmp.u16 >> 7); +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; } -template -__device__ void topKPerRowJob(const float* logits, const int rowStart, - const int rowEnd, const int rowIdx, - int* outIndices, int stride0, int stride1) { - // The number of elements per thread for the final top-k sort. - static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; - // The class to sort the elements during the final top-k sort. - using TopKSort = cub::BlockRadixSort; +template +static inline __device__ uint32_t extractBinIdx(float x) { + if constexpr (step == 0) { + __half hx = __float2half(x); + uint16_t bits = __half_as_ushort(hx); + bits = (bits & 0x8000) ? bits : ~bits & 0x7fff; + return bits >> 5; + } else { + uint32_t bits = __float_as_uint(x); + bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; + + if constexpr (step == 1) { + return bits >> 21; + } else if constexpr (step == 2) { + return (bits >> 10) & 0x7ff; + } else if constexpr (step == 3) { + return bits & 0x3ff; + } + } +} + +template +static inline __device__ bool isPartialMatch(float x, uint32_t pattern) { + if constexpr (shift == 0) { + return true; + } + uint32_t bits = __float_as_uint(x); + bits = (bits & 0x80000000) ? 
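convert_to_uint32 and the templated extractBinIdx build a radix key that is strictly decreasing in the logit value: non-negative floats have their bits inverted (so larger values get smaller keys), while negative floats keep their bit pattern, which is already larger than any inverted non-negative key. Step 0 applies the same trick to the 16-bit half encoding and keeps its top 11 bits; steps 1-3 slice the 32-bit key into 11/11/10-bit bins. A host-side sketch of the key and the step-wise bin extraction, with std::memcpy standing in for __float_as_uint:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Order-reversing key: a larger logit maps to a smaller key, so the top-k
// values land in the lowest histogram bins.
uint32_t float_key(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return (bits & 0x80000000u) ? bits : ~bits & 0x7fffffffu;
}

// Bin index for refinement steps 1..3 (step 0 uses the half-precision bits instead).
uint32_t bin_idx(float x, int step) {
  uint32_t key = float_key(x);
  if (step == 1) return key >> 21;            // top 11 bits -> 2048 bins
  if (step == 2) return (key >> 10) & 0x7ff;  // next 11 bits
  return key & 0x3ff;                         // last 10 bits
}

int main() {
  float a = 5.0f, b = 2.0f, c = -1.0f;
  // Keys reverse the value order: a > b > c, so key(a) < key(b) < key(c).
  printf("keys: %u %u %u\n", float_key(a), float_key(b), float_key(c));
  printf("step-1 bins: %u %u %u\n", bin_idx(a, 1), bin_idx(b, 1), bin_idx(c, 1));
  return 0;
}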
bits : ~bits & 0x7fffffff; + return (bits ^ pattern) >> shift == 0; +} + +/** + * Map a Func over the input data, using vectorized load instructions if + * possible. + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, + const T* in, idxT len, Func f) { + constexpr int WARP_SIZE = 32; + using WideT = float4; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = thread_rank; i < len; i += num_threads) { + f(in[i], i); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + // TODO: it's UB + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / + sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT* in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + for (idxT i = thread_rank; i < len_cast; i += num_threads) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if (thread_rank < skip_cnt) { + f(in[thread_rank], thread_rank); + } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= + // WARP_SIZE no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + if (remain_i < len) { + f(in[remain_i], remain_i); + } + } +} + +template +__device__ bool processHistogramStep( + const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern, + int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx, + int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues, + SmemFinalType& smemFinal, int stride1, int rowStart, int topK) { + // Clear the histogram. +#pragma unroll + for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) { + smemFinal.histo.data[idx] = 0; + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Update pattern + constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10; + if constexpr (step == 2) { + logitPattern = static_cast(thresholdBinIdx & 0x7ff) + << patternShift; + } else if constexpr (step == 3) { + logitPattern |= static_cast(thresholdBinIdx & 0x7ff) + << patternShift; + } + + auto distributeToBins = [&](float logit, int /* idx */ = 0) { + if (isPartialMatch(logit, logitPattern)) { + uint32_t binIdx = extractBinIdx(logit); + atomicAdd(&smemFinal.histo.data[binIdx], 1); + } + }; + + // Distribute the elements to the histogram bins. 
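vectorized_process walks a row with 16-byte float4 loads when the element type allows it: a scalar prefix up to the first aligned address, a vectorized body, then a scalar tail, calling the user lambda once per element with its original index. The sequential sketch below reproduces the same prefix/body/tail bookkeeping (without actual wide loads or the CUDA thread striding), so the indexing can be checked on the host:

#include <cstdint>
#include <cstddef>
#include <vector>
#include <cstdio>

// Single-threaded sketch of the prefix / body / tail decomposition used by
// vectorized_process (4 floats per 16-byte "wide" load).
template <typename Func>
void process_vectorized(const float* in, size_t len, Func f) {
  constexpr size_t kWideBytes = 16, kPerWide = kWideBytes / sizeof(float);
  size_t addr = reinterpret_cast<uintptr_t>(in) % kWideBytes;
  size_t skip = addr ? (kWideBytes - addr) / sizeof(float) : 0;  // misaligned prefix
  if (skip > len) skip = len;
  for (size_t i = 0; i < skip; ++i) f(in[i], i);                 // prefix, scalar
  size_t body = (len - skip) / kPerWide;
  for (size_t v = 0; v < body; ++v)             // body: groups of 4 elements that the
    for (size_t j = 0; j < kPerWide; ++j) {     // kernel would fetch with one float4 load
      size_t idx = skip + v * kPerWide + j;
      f(in[idx], idx);
    }
  for (size_t i = skip + body * kPerWide; i < len; ++i) f(in[i], i);  // tail, scalar
}

int main() {
  std::vector<float> data(11);
  for (size_t i = 0; i < data.size(); ++i) data[i] = float(i);
  float sum = 0.f;
  process_vectorized(data.data() + 1, data.size() - 1,  // deliberately misaligned start
                     [&](float x, size_t) { sum += x; });
  printf("%g\n", sum);  // 1 + 2 + ... + 10 = 55, every element visited exactly once
  return 0;
}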
+ if (stride1 == 1) { + vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, + rowEnd - rowStart, distributeToBins); + } else { + for (int idx = rowStart + threadIdx.x; idx < rowEnd; + idx += kNumThreadsPerBlock) { + float logit = logits[idx * stride1]; + distributeToBins(logit, idx); + } + } + // Make sure the histogram is ready. + __syncthreads(); + + // Reads the value of the starting position in the smemOutput array + int lastValue = smemFoundTopKValues[0]; + + for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) { + // Read the values from SMEM. + int idx = threadIdx.x + kNumThreadsPerBlock * round; + int binCount{0}; + binCount = smemFinal.histo.data[idx]; + + // Make sure each thread has read its value. + __syncthreads(); + + // Compute the prefix sum. + int prefixSum{0}, totalSum{0}; + using Scan = cub::BlockScan; + Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum); + + // Update the histogram with the prefix sums. + prefixSum += lastValue; + totalSum += lastValue; + smemFinal.histo.data[idx] = prefixSum; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Find the last valid bin. + bool foundThreshold = false; + if (prefixSum < topK) { + int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1 + ? totalSum + : smemFinal.histo.data[idx + 1]; + + if (nextPrefixSum >= topK) { + smemThresholdBinIdx[0] = idx; + smemFinalBinSize[0] = nextPrefixSum - prefixSum; + foundThreshold = true; + } + } + + // Early exit: if any thread found the threshold, we can skip remaining + // rounds + if (__syncthreads_or(foundThreshold)) { + break; + } + + lastValue = totalSum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The threshold bin. + thresholdBinIdx = smemThresholdBinIdx[0]; + + auto processBins = [&](float logit, int idx) { + if (isPartialMatch(logit, logitPattern)) { + uint32_t binIdx = extractBinIdx(logit); + if (binIdx < thresholdBinIdx) { + // The element is part of the top-k selection + int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1); + + if constexpr (mergeBlocks) { + smemOutput[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemOutput[dstIdx] = idx + rowStart; + reinterpret_cast(smemOutput + topK)[dstIdx] = logit; + } else { + smemOutput[dstIdx] = idx; + } + } + if constexpr (step < 3) { + // Only fill the final items for sorting if the threshold bin fits + if (binIdx == thresholdBinIdx && + smemFinalBinSize[0] <= kNumFinalItems) { + int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); + smemFinal.items.logits[dstIdx] = logit; + if constexpr (mergeBlocks) { + smemFinal.items.indices[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemFinal.items.indices[dstIdx] = idx + rowStart; + } else { + smemFinal.items.indices[dstIdx] = idx; + } + } + } else { + if (binIdx == thresholdBinIdx) { + // The elements in the threshold bin share the same 32 bits at step 3 + int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1); + if (dstIdx < topK) { + if constexpr (mergeBlocks) { + smemOutput[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemOutput[dstIdx] = idx + rowStart; + reinterpret_cast(smemOutput + topK)[dstIdx] = logit; + } else { + smemOutput[dstIdx] = idx; + } + } + } + } + } + }; + + if (stride1 == 1) { + vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, + rowEnd - rowStart, processBins); + } else { + for (int idx = rowStart + threadIdx.x; idx < rowEnd; + idx += kNumThreadsPerBlock) { + 
float logit = logits[idx * stride1]; + processBins(logit, idx); + } + } + + // Make sure the elements are in shared memory. + __syncthreads(); + + // Check if we should continue to next step + return smemFinalBinSize[0] > kNumFinalItems; +} +// Follows half - 11 - 11 - 10 bit iterations +template +static __device__ void topKPerRowJob(const int* indices, const float* logits, + int rowStart, int rowEnd, int* outIndices, + float* outLogits, int stride1, int topK) { // The number of slots for the final pass. - static constexpr int kNumFinalItems = 3072; + static constexpr int kNumFinalItems = 2048; // The number of elements per thread for the final sort. static constexpr int kNumFinalItemsPerThread = kNumFinalItems / kNumThreadsPerBlock; // The class to sort the elements during the final pass. using FinalSort = cub::BlockRadixSort; - + using FinalSortTempStorage = + std::conditional_t; // The class to compute the inclusive prefix-sum over the histogram. using Scan = cub::BlockScan; - // Shared memory to compute the block scan. - __shared__ typename Scan::TempStorage smemScan; - // The structure to store the final items (for the final pass). struct FinalItems { // Shared memory to store the indices for the final pass. @@ -87,200 +346,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart, float logits[kNumFinalItems]; }; + struct Histogram { + typename Scan::TempStorage scan; + int data[kNumBins]; + }; + // Shared memory to compute the block sort. __shared__ union { FinalItems items; - typename FinalSort::TempStorage finalSort; - typename TopKSort::TempStorage topKSort; + FinalSortTempStorage finalSort; + Histogram histo; } smemFinal; - // Shared memory to store the histogram. - __shared__ int smemHistogram[kNumBins]; // Shared memory to store the selected indices. - __shared__ int smemIndices[kTopK]; + // If we are processing using multiple blocks, we need to store the logits and + // indices. + extern __shared__ int32_t smemOutput[]; + // Shared memory to store the threshold bin. __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. __shared__ int smemFinalDstIdx[1]; + // Shared memory to determine if the threshold bin fits in the final items. + __shared__ int smemFinalBinSize[1]; + // Shared memory to keep track of the top-k values found so far by the + // previous iterations + __shared__ int smemFoundTopKValues[1]; // The length of the row. int rowLen = rowEnd - rowStart; // Shortcut if the length of the row is smaller than Top-K. Indices are not // sorted by their corresponding logit. - if (rowLen <= kTopK) { + if (rowLen <= topK) { for (int rowIt = threadIdx.x; rowIt < rowLen; rowIt += kNumThreadsPerBlock) { - int idx = rowStart + rowIt; - outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; + if constexpr (multipleBlocksPerRow) { + outIndices[rowIt] = rowIt + rowStart; + outLogits[rowIt] = logits[rowIt + rowStart]; + } else { + outIndices[rowIt] = rowIt; + } } - for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; + for (int rowIt = rowLen + threadIdx.x; rowIt < topK; rowIt += kNumThreadsPerBlock) { - outIndices[rowIdx * kTopK + rowIt] = -1; + outIndices[rowIt] = -1; + if constexpr (multipleBlocksPerRow) { + outLogits[rowIt] = -FLT_MAX; + } } - return; - } - - // Clear the histogram. - if (threadIdx.x < kNumBins) { - smemHistogram[threadIdx.x] = 0; - } - - // Make sure the histogram is ready. - __syncthreads(); - - // Fetch elements one-by-one. 
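processHistogramStep is one pass of an iterative radix select over the order-reversing keys: histogram the surviving candidates into 2048 bins, prefix-sum the bins, emit every element strictly below the threshold bin, and either hand the threshold bin to the final sort (when it holds at most kNumFinalItems elements) or refine it with the next bit slice. The single-threaded reference below sketches that selection loop on 32-bit keys with the kernel's 11/11/10-bit slices; it always refines instead of short-circuiting into a final sort, and assumes at least topK keys:

#include <cstdint>
#include <vector>
#include <cstdio>

// Reference radix select: returns indices of the topK smallest keys
// (i.e. the topK largest logits under the order-reversing key).
std::vector<int> radix_select(const std::vector<uint32_t>& keys, int topK) {
  std::vector<int> selected;                    // elements strictly below the threshold bin
  std::vector<int> candidates(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) candidates[i] = int(i);
  const int shifts[3] = {21, 10, 0};
  const uint32_t masks[3] = {0x7ff, 0x7ff, 0x3ff};  // 11, 11, then 10 bits
  for (int step = 0; step < 3 && int(selected.size()) < topK; ++step) {
    std::vector<int> histo(masks[step] + 1, 0);
    for (int idx : candidates) histo[(keys[idx] >> shifts[step]) & masks[step]]++;
    // Threshold bin: first bin whose inclusive prefix sum reaches the remaining need.
    int need = topK - int(selected.size()), running = 0, threshold = 0;
    while (running + histo[threshold] < need) running += histo[threshold++];
    std::vector<int> next;
    for (int idx : candidates) {
      uint32_t bin = (keys[idx] >> shifts[step]) & masks[step];
      if (bin < uint32_t(threshold)) selected.push_back(idx);    // definitely in the top-k
      else if (bin == uint32_t(threshold)) next.push_back(idx);  // ambiguous, refine next step
    }
    candidates.swap(next);
  }
  // After the last slice the remaining candidates share the full key; take what is still needed.
  for (int idx : candidates)
    if (int(selected.size()) < topK) selected.push_back(idx);
  return selected;
}

int main() {
  std::vector<uint32_t> keys = {7, 3, 9, 3, 1, 8, 2};
  for (int idx : radix_select(keys, 3)) printf("%d ", idx);  // indices of keys 1, 2, 3
  printf("\n");
  return 0;
}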
- for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; - rowIt += kNumThreadsPerBlock) { - uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]); - atomicAdd(&smemHistogram[idx], 1); - } - - // Make sure the histogram is ready. - __syncthreads(); - - // Read the values from SMEM. - int binCount{0}; - if (threadIdx.x < kNumBins) { - binCount = smemHistogram[threadIdx.x]; - } - - // Make sure each thread has read its value. - __syncthreads(); - - // Compute the prefix sum. - int prefixSum{0}, totalSum{0}; - Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum); - - // Update the histogram with the prefix sums. - if (threadIdx.x < kNumBins) { - smemHistogram[threadIdx.x] = prefixSum; - } - // Make sure the data is in shared memory. - __syncthreads(); - - // Find the last valid bin. - if (threadIdx.x < kNumBins) { - int nextPrefixSum = - threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1]; - if (prefixSum < kTopK && nextPrefixSum >= kTopK) { - smemThresholdBinIdx[0] = threadIdx.x; - } + return; } - - // Clear the counter to store the items for the final phase. + // Initialize values if (threadIdx.x == 0) { smemFinalDstIdx[0] = 0; + smemFoundTopKValues[0] = 0; } - - // Make sure the data is in shared memory. __syncthreads(); + int thresholdBinIdx = -1; + uint32_t logitPattern = 0; + + // Step 0: Process first 11 bits of half representation + bool continueToNextStep = + processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + + if (continueToNextStep) { + // Step 1: Process next 11 bits + continueToNextStep = + processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + } - // The threshold bin. - int thresholdBinIdx = smemThresholdBinIdx[0]; - - // Fetch elements one-by-one and populate the shared memory buffers. - for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; - rowIt += kNumThreadsPerBlock) { - float logit = logits[rowIdx * stride0 + rowIt * stride1]; - uint16_t idx = extractBinIdx(logit); - if (idx < thresholdBinIdx) { - int dstIdx = atomicAdd(&smemHistogram[idx], 1); - smemIndices[dstIdx] = rowIt; - } else if (idx == thresholdBinIdx) { - int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); - if (dstIdx < kNumFinalItems) { - smemFinal.items.logits[dstIdx] = logit; - smemFinal.items.indices[dstIdx] = rowIt; - } - } + if (continueToNextStep) { + // Step 2: Process next 11 bits + continueToNextStep = + processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); } - // Make sure the elements are in shared memory. 
- __syncthreads(); + if (continueToNextStep) { + // Step 3: Process last 10 bits + processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + } - // The logits of the elements to be sorted in the final pass. - float finalLogits[kNumFinalItemsPerThread]; - // The indices of the elements to be sorted in the final pass. - int finalIndices[kNumFinalItemsPerThread]; + if (!continueToNextStep) { + // The histogram did not proceed to the final 10 bits, therefore we need to + // sort the final items The logits of the elements to be sorted in the final + // pass. + if constexpr (useRadixSort) { + // Sorting with radix sort + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; -// Init. #pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - finalLogits[ii] = -FLT_MAX; - } + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + finalLogits[ii] = -FLT_MAX; + } -// Read the elements from SMEM. + // Read the elements from SMEM. #pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - if (srcIdx < smemFinalDstIdx[0]) { - finalLogits[ii] = smemFinal.items.logits[srcIdx]; - finalIndices[ii] = smemFinal.items.indices[srcIdx]; - } - } + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + if (srcIdx < smemFinalDstIdx[0]) { + finalLogits[ii] = smemFinal.items.logits[srcIdx]; + finalIndices[ii] = smemFinal.items.indices[srcIdx]; + } + } + // Make sure the shared memory has been read. + __syncthreads(); - // Make sure the shared memory has been read. - __syncthreads(); + // Sort the elements. + FinalSort(smemFinal.finalSort) + .SortDescendingBlockedToStriped(finalLogits, finalIndices); - // Sort the elements. - FinalSort(smemFinal.finalSort) - .SortDescendingBlockedToStriped(finalLogits, finalIndices); + // Copy the data back to the shared memory storage. + int baseIdx = smemFoundTopKValues[0]; - // Copy the data back to the shared memory storage. - int baseIdx = thresholdBinIdx > 0 ? 
smemHistogram[thresholdBinIdx - 1] : 0; #pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - int dstIdx = baseIdx + srcIdx; - if (dstIdx < kTopK) { - smemIndices[dstIdx] = finalIndices[ii]; + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + int dstIdx = baseIdx + srcIdx; + + if (dstIdx < topK) { + smemOutput[dstIdx] = finalIndices[ii]; + if constexpr (multipleBlocksPerRow) { + reinterpret_cast(smemOutput + topK)[dstIdx] = + finalLogits[ii]; + } + } + } + } else { + // Sorting with insertion sort + auto baseIdx = smemFoundTopKValues[0]; + for (int i = threadIdx.x; i < smemFinalDstIdx[0]; + i += kNumThreadsPerBlock) { + int outIndex = 0; + auto logit = smemFinal.items.logits[i]; + for (int j = 0; j < smemFinalDstIdx[0]; j++) { + auto otherLogit = smemFinal.items.logits[j]; + if (logit < otherLogit || (logit == otherLogit && i < j)) { + outIndex++; + } + } + // Store if outIndex is in bounds + if (outIndex + baseIdx < topK) { + smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i]; + if constexpr (multipleBlocksPerRow) { + reinterpret_cast(smemOutput + topK)[outIndex + baseIdx] = + smemFinal.items.logits[i]; + } + } + } } + __syncthreads(); } - // Make sure the data is in shared memory. - __syncthreads(); - -// Store to global memory. -#pragma unroll - for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { - int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; - outIndices[offset] = - smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; + // Store to global memory. + for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) { + if constexpr (multipleBlocksPerRow) { + outIndices[i] = smemOutput[i]; + outLogits[i] = reinterpret_cast(smemOutput + topK)[i]; + } else { + if (stride1 == 1) { + // stride1 == 1 will use vectorized_process, which indexes already skip + // the rowStart. + outIndices[i] = smemOutput[i]; + } else { + outIndices[i] = smemOutput[i] - rowStart; + } + } } } -template -static __global__ void topKPerRow(const float* logits, const int* rowStarts, - const int* rowEnds, int* outIndices, - int stride0, int stride1) { +template +static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill( + const float* logits, const int* rowStarts, const int* rowEnds, + int* outIndices, int stride0, int stride1, const int topK, + const int offsetIndex) { // The number of bins in the histogram. - static constexpr int kNumBins = 512; - - // The top-k width. - static constexpr int kTopK = 2048; + static constexpr int kNumBins = 2048; // The row computed by this block. - int rowIdx = blockIdx.x; + int rowIdx = blockIdx.x + offsetIndex; // The range of logits within the row. 
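The "insertion sort" final pass above is really a rank-by-counting placement, selected when the decode entry point picks the non-radix variant for short rows: each thread takes one surviving candidate, counts how many candidates outrank it (ties broken by original position), and writes it straight into that output slot, so no element is moved twice. A tiny sequential sketch of the ranking rule:

#include <vector>
#include <cstdio>

// Rank-by-counting placement: each element's output slot is the number of
// elements that rank above it in descending-logit order.
std::vector<int> rank_sort_desc(const std::vector<float>& logits) {
  std::vector<int> slot(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {  // one CUDA thread per i in the kernel
    int outIndex = 0;
    for (size_t j = 0; j < logits.size(); ++j)
      if (logits[i] < logits[j] || (logits[i] == logits[j] && i < j)) ++outIndex;
    slot[i] = outIndex;
  }
  return slot;
}

int main() {
  std::vector<float> l = {0.3f, 0.9f, 0.1f, 0.9f};
  auto slots = rank_sort_desc(l);
  for (size_t i = 0; i < l.size(); ++i) printf("logit %g -> slot %d\n", l[i], slots[i]);
  // 0.9 (i=3) -> slot 0, 0.9 (i=1) -> slot 1, 0.3 -> slot 2, 0.1 -> slot 3
  return 0;
}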
int rowStart = rowStarts[rowIdx]; int rowEnd = rowEnds[rowIdx]; - topKPerRowJob( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + outIndices += rowIdx * topK; + logits += rowIdx * stride0; + + topKPerRowJob( + nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); } -template -static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, - int* outIndices, int stride0, - int stride1, int next_n) { +template +static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( + const float* logits, const int* seqLens, int* outIndices, int stride0, + int stride1, const int topK, int next_n, float* outLogits = nullptr, + const int numBlocksToMerge = 0, const int* indices = nullptr) { // The number of bins in the histogram. - static constexpr int kNumBins = 512; - - // The top-k width. - static constexpr int kTopK = 2048; + static constexpr int kNumBins = 2048; // The row computed by this block. int rowIdx = blockIdx.x; @@ -290,8 +574,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, int seq_len = seqLens[rowIdx / next_n]; int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; - topKPerRowJob( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + if constexpr (!multipleBlocksPerRow && !mergeBlocks) { + outIndices += rowIdx * topK; + } else if constexpr (multipleBlocksPerRow) { + const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 + rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 + rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; + outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; + outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; + } else if constexpr (mergeBlocks) { + rowEnd = numBlocksToMerge * topK; + indices += rowIdx * numBlocksToMerge * topK; + outIndices += rowIdx * topK; + } + logits += rowIdx * stride0; + + topKPerRowJob( + indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK); } } // namespace vllm @@ -339,28 +640,84 @@ void apply_repetition_penalties_( void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, const torch::Tensor& seqLens, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1) { - // Compute the results on the device. 
+ int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK) { + constexpr int kSortingAlgorithmThreshold = 12288; + constexpr int kSplitWorkThreshold = 200 * 1000; constexpr int kNumThreadsPerBlock = 512; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - vllm::topKPerRowDecode - <<>>( - logits.data_ptr(), seqLens.data_ptr(), - indices.data_ptr(), static_cast(stride0), - static_cast(stride1), static_cast(next_n)); + const auto numColumns = logits.size(1); + + if (numColumns < kSortingAlgorithmThreshold) { + // Use insertion sort + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n)); + } else if (numColumns < kSplitWorkThreshold) { + // From this threshold, use radix sort instead + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n)); + } else { + // Long sequences are run in two steps + constexpr auto multipleBlocksPerRowConfig = 10; + + const auto outIndicesAux = + torch::empty({numRows, multipleBlocksPerRowConfig, topK}, + torch::dtype(torch::kInt32).device(logits.device())); + const auto outLogitsAux = + torch::empty({numRows, multipleBlocksPerRowConfig, topK}, + torch::dtype(torch::kFloat).device(logits.device())); + + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + outIndicesAux.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n), outLogitsAux.data_ptr()); + + constexpr int kNumThreadsPerBlockMerge = 1024; + vllm::topKPerRowDecode + <<>>( + outLogitsAux.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), multipleBlocksPerRowConfig * topK, 1, + static_cast(topK), static_cast(next_n), nullptr, + multipleBlocksPerRowConfig, outIndicesAux.data_ptr()); + } } -void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1) { - // Compute the results on the device. 
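For rows at or above kSplitWorkThreshold the decode entry point runs two launches: gridDim.y = multipleBlocksPerRowConfig blocks each select a local top-k over one contiguous slice of the row, writing indices and logits to auxiliary tensors, and a second launch treats those numBlocksToMerge * topK candidates as a short row, picks the global top-k, and maps positions back through the saved indices. A sequential sketch of that split/merge reduction, using std::partial_sort in place of the histogram kernel:

#include <vector>
#include <algorithm>
#include <cstdio>

struct Cand { float logit; int index; };

// Local top-k over [start, end) of a row, by descending logit.
std::vector<Cand> local_topk(const std::vector<float>& row, int start, int end, int topK) {
  std::vector<Cand> c;
  for (int i = start; i < end; ++i) c.push_back({row[i], i});
  size_t keep = std::min<size_t>(topK, c.size());
  std::partial_sort(c.begin(), c.begin() + keep, c.end(),
                    [](const Cand& a, const Cand& b) { return a.logit > b.logit; });
  c.resize(keep);
  return c;
}

// Two-stage selection: split the row into numBlocks slices, then merge the
// per-slice candidates into the final topK original indices.
std::vector<int> split_merge_topk(const std::vector<float>& row, int numBlocks, int topK) {
  std::vector<Cand> merged;
  int blockSize = int(row.size()) / numBlocks;
  for (int b = 0; b < numBlocks; ++b) {
    int start = b * blockSize;
    int end = (b == numBlocks - 1) ? int(row.size()) : start + blockSize;
    auto local = local_topk(row, start, end, topK);  // first launch: per-block top-k + logits
    merged.insert(merged.end(), local.begin(), local.end());
  }
  std::vector<float> candLogits;                     // second launch input: numBlocks * topK logits
  for (const auto& c : merged) candLogits.push_back(c.logit);
  auto global = local_topk(candLogits, 0, int(candLogits.size()), topK);
  std::vector<int> out;
  for (const auto& g : global) out.push_back(merged[g.index].index);  // map back to original positions
  return out;
}

int main() {
  std::vector<float> row(1000);
  for (int i = 0; i < 1000; ++i) row[i] = float((i * 37) % 1000);  // a permutation of 0..999
  auto top = split_merge_topk(row, /*numBlocks=*/10, /*topK=*/4);
  for (int idx : top) printf("%d:%g ", idx, row[idx]);  // the four largest values
  printf("\n");
  return 0;
}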
+void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK) { + constexpr int kSortingAlgorithmThreshold = 12288; constexpr int kNumThreadsPerBlock = 512; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::topKPerRow - <<>>( - logits.data_ptr(), rowStarts.data_ptr(), - rowEnds.data_ptr(), indices.data_ptr(), - static_cast(stride0), static_cast(stride1)); + int numInsertionBlocks = + std::min(static_cast(numRows), kSortingAlgorithmThreshold); + vllm::topKPerRowPrefill + <<>>(logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + static_cast(stride0), static_cast(stride1), + static_cast(topK), 0); + + if (numRows > kSortingAlgorithmThreshold) { + int numRadixBlocks = numRows - kSortingAlgorithmThreshold; + vllm::topKPerRowPrefill + <<>>(logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + static_cast(stride0), static_cast(stride1), + static_cast(topK), kSortingAlgorithmThreshold); + } } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 14913bef1312..83d4943d6277 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); -#ifndef USE_ROCM // Merge attn states // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // can be used to combine partial attention results (in the split-KV case) @@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); - +#ifndef USE_ROCM ops.def( "convert_vertical_slash_indexes(" " Tensor! block_count, Tensor! block_offset, " @@ -180,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Optimized top-k per row operation ops.def( - "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " + "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, " "Tensor! indices, int numRows, int stride0, " - "int stride1) -> ()"); - ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); + "int stride1, int topK) -> ()"); + ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill); ops.def( "top_k_per_row_decode(Tensor logits, int next_n, " - "Tensor seq_lens, Tensor! indices, int numRows, " - "int stride0, int stride1) -> ()"); + "Tensor seq_lens, Tensor! indices, " + "int numRows, int stride0, int stride1, int topK) -> ()"); ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); // Layernorm-quant @@ -216,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, &rms_norm_dynamic_per_token_quant); + // Fused Layernorm + Block quant kernels + ops.def( + "rms_norm_per_block_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual, int group_size, " + "bool is_scale_transposed) -> ()"); + ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( @@ -299,9 +306,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. 
ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor? b_bias_or_none," - "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " - "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor? b_bias_or_none,Tensor b_scales, " + "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, " + "Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file @@ -309,13 +317,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin repack from GPTQ. ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " - "SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); // conditionally compiled so impl registrations are in source file // awq_marlin repack from AWQ. ops.def( "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " - "SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); + // conditionally compiled so impl registrations are in source file + + // preprocess W-int4A-fp8 weight for marlin kernel + ops.def( + "marlin_int4_fp8_preprocess(Tensor qweight, " + "Tensor? qzeros_or_none, bool inplace) -> Tensor"); // conditionally compiled so impl registrations are in source file // CUTLASS w4a8 GEMM @@ -336,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); // conditionally compiled so impl registration is in source file + // CUTLASS w4a8 grouped GEMM + ops.def( + "cutlass_w4a8_moe_mm(" + " Tensor! out_tensors," + " Tensor a_tensors," + " Tensor b_tensors," + " Tensor a_scales," + " Tensor b_scales," + " Tensor b_group_scales," + " int b_group_size," + " Tensor expert_offsets," + " Tensor problem_sizes," + " Tensor a_strides," + " Tensor b_strides," + " Tensor c_strides," + " Tensor group_scale_strides," + " str? maybe_schedule" + ") -> ()"); + ops.def( + "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " + "Tensor)"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -452,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes1, " " Tensor! problem_sizes2, " " int num_experts, int n, int k, " - " Tensor? blockscale_offsets) -> ()"); + " Tensor? blockscale_offsets, " + " bool? force_swap_ab) -> ()"); ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, &get_cutlass_moe_mm_problem_sizes); @@ -611,6 +649,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("per_token_group_fp8_quant", torch::kCUDA, &per_token_group_quant_fp8); + // Compute per-token-group 8-bit quantized tensor and UE8M0-packed, + // TMA-aligned scales for DeepGEMM. + ops.def( + "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, " + "Tensor! output_s_packed, int group_size, float eps, float fp8_min, " + "float fp8_max) -> ()"); + ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA, + &per_token_group_quant_8bit_packed); + // Compute per-token-group INT8 quantized tensor and scaling factor. ops.def( "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! 
" @@ -707,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + cache_ops.def( + "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, " + "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " + "batch_size) -> ()"); + cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, + &cp_gather_and_upconvert_fp8_kv_cache); + cache_ops.def( "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "slot_mapping, " diff --git a/docker/Dockerfile b/docker/Dockerfile index 84a1802dbe03..0d50d97e54c6 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -150,8 +150,8 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} #################### BASE BUILD IMAGE #################### -#################### WHEEL BUILD IMAGE #################### -FROM base AS build +#################### CSRC BUILD IMAGE #################### +FROM base AS csrc-build ARG TARGETPLATFORM ARG PIP_INDEX_URL UV_INDEX_URL @@ -172,10 +172,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -COPY . . -ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi +WORKDIR /workspace + +COPY pyproject.toml setup.py CMakeLists.txt ./ +COPY cmake cmake/ +COPY csrc csrc/ +COPY vllm/envs.py vllm/envs.py +COPY vllm/__init__.py vllm/__init__.py # max jobs used by Ninja to build extensions ARG max_jobs=2 @@ -193,11 +196,14 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0 # Flag to control whether to use pre-built vLLM wheels ARG VLLM_USE_PRECOMPILED="" +ARG VLLM_MERGE_BASE_COMMIT="" ARG VLLM_MAIN_CUDA_VERSION="" +# Use dummy version for csrc-build wheel (only .so files are extracted, version doesn't matter) +ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+csrc.build" + # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." 
\ && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \ @@ -211,6 +217,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ + && export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" \ && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ && export VLLM_DOCKER_BUILD_CONTEXT=1 \ && sccache --show-stats \ @@ -223,15 +230,61 @@ ENV VLLM_TARGET_DEVICE=${vllm_target_device} ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" != "1" ]; then \ # Clean any existing CMake artifacts rm -rf .deps && \ mkdir -p .deps && \ export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \ + export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" && \ export VLLM_DOCKER_BUILD_CONTEXT=1 && \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi +#################### CSRC BUILD IMAGE #################### + +#################### WHEEL BUILD IMAGE #################### +FROM base AS build +ARG TARGETPLATFORM + +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL + +# install build dependencies +COPY requirements/build.txt requirements/build.txt + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + +WORKDIR /workspace + +COPY --from=csrc-build /workspace/dist /precompiled-wheels + +COPY . . 
+ +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi + +ARG vllm_target_device="cuda" +ENV VLLM_TARGET_DEVICE=${vllm_target_device} + +# Skip adding +precompiled suffix to version (preserves git-derived version) +ENV VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX=1 + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "${vllm_target_device}" = "cuda" ]; then \ + export VLLM_PRECOMPILED_WHEEL_LOCATION=$(ls /precompiled-wheels/*.whl); \ + fi && \ + python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 # Install DeepGEMM from source ARG DEEPGEMM_GIT_REF @@ -244,9 +297,15 @@ RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh # Install EP kernels(pplx-kernels and DeepEP) +ARG PPLX_COMMIT_HASH +ARG DEEPEP_COMMIT_HASH RUN --mount=type=cache,target=/root/.cache/uv \ export TORCH_CUDA_ARCH_LIST='9.0a 10.0a' && \ - /tmp/install_python_libraries.sh /tmp/ep_kernels_workspace wheel && \ + /tmp/install_python_libraries.sh \ + --workspace /tmp/ep_kernels_workspace \ + --mode wheel \ + ${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \ + ${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} && \ find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete # Check the size of the wheel if RUN_WHEEL_CHECK is true @@ -358,7 +417,12 @@ RUN CUDA_VERSION_DASH=$(echo $CUDA_VERSION | cut -d. -f1,2 | tr '.' '-') && \ cuda-cudart-${CUDA_VERSION_DASH} \ cuda-nvrtc-${CUDA_VERSION_DASH} \ cuda-cuobjdump-${CUDA_VERSION_DASH} \ - libcublas-${CUDA_VERSION_DASH} && \ + # https://github.com/vllm-project/vllm/issues/29590 + libcurand-dev-${CUDA_VERSION_DASH} \ + libcublas-${CUDA_VERSION_DASH} \ + # Fixes nccl_allocator requiring nccl.h at runtime + # https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22 + libnccl-dev && \ rm -rf /var/lib/apt/lists/* ARG PIP_INDEX_URL UV_INDEX_URL @@ -392,8 +456,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer pre-compiled kernel cache and binaries # https://docs.flashinfer.ai/installation.html RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system flashinfer-cubin==0.5.2 \ - && uv pip install --system flashinfer-jit-cache==0.5.2 \ + uv pip install --system flashinfer-cubin==0.5.3 \ + && uv pip install --system flashinfer-jit-cache==0.5.3 \ --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ && flashinfer show-config @@ -516,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0' + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3' ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index eb3807ef0ca4..8d55ecfba3e5 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -132,7 +132,7 @@ RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ esac; \ }; \ remove_packages_not_supported_on_aarch64 && \ - sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.9.1/g' requirements/cpu-test.in && \ sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 6d19dc601f10..a72a6b6e5cac 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -65,6 +65,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1 # ----------------------- # Test vLLM image @@ -88,10 +89,22 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace # install development dependencies (for testing) RUN cd /vllm-workspace \ - && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ && python3 -m pip install pytest-shard +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER=1 + +# Copy in the v1 package (for python-only install test group) +COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 + +# Source code is used in the `python_only_compile.sh` test +# We hide it inside `src/` so that this source code +# will not be imported by other tests +RUN mkdir src && mv vllm src/vllm + # ----------------------- # Final vLLM image FROM base AS final diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index df4f9b6c26e7..a57ee728d924 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -5,6 +5,8 @@ ARG PYTORCH_BRANCH="1c57644d" ARG PYTORCH_VISION_BRANCH="v0.23.0" ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +ARG PYTORCH_AUDIO_BRANCH="v2.9.0" +ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git" ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG AITER_BRANCH="59bd8ff2" @@ -23,6 +25,7 @@ ENV AITER_ROCM_ARCH=gfx942;gfx950 ENV HSA_NO_SCRATCH_RECLAIM=1 ARG PYTHON_VERSION=3.12 +ENV PYTHON_VERSION=${PYTHON_VERSION} RUN mkdir -p /app WORKDIR /app @@ -45,6 +48,7 @@ RUN apt-get update -y \ && python3 --version && python3 -m pip --version RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' 
pybind11 Cython +RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/* FROM base AS build_triton ARG TRITON_BRANCH @@ -66,11 +70,14 @@ RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install FROM base AS build_pytorch ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_AUDIO_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO +ARG PYTORCH_AUDIO_REPO + RUN git clone ${PYTORCH_REPO} pytorch -RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ - pip install -r requirements.txt && git submodule update --init --recursive \ +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \ + && pip install -r requirements.txt && git submodule update --init --recursive \ && python3 tools/amd_build/build_amd.py \ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ && pip install dist/*.whl @@ -78,8 +85,15 @@ RUN git clone ${PYTORCH_VISION_REPO} vision RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ && python3 setup.py bdist_wheel --dist-dir=dist \ && pip install dist/*.whl +RUN git clone ${PYTORCH_AUDIO_REPO} audio +RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ - && cp /app/vision/dist/*.whl /app/install + && cp /app/vision/dist/*.whl /app/install \ + && cp /app/audio/dist/*.whl /app/install FROM base AS build_fa ARG FA_BRANCH @@ -130,6 +144,8 @@ ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO +ARG PYTORCH_AUDIO_BRANCH +ARG PYTORCH_AUDIO_REPO ARG FA_BRANCH ARG FA_REPO ARG AITER_BRANCH @@ -141,7 +157,9 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ - && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index adac43c6accb..72d2053102c2 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils ENV NIXL_VERSION=0.7.0 RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts +RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf + # remove torch bundled oneccl to avoid conflicts RUN --mount=type=cache,target=/root/.cache/pip \ pip uninstall oneccl oneccl-devel -y diff --git a/docs/.nav.yml b/docs/.nav.yml index c8bf00efb237..835cc773e759 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -5,11 +5,7 @@ nav: - Getting Started: - getting_started/quickstart.md - getting_started/installation - - Examples: - - examples/README.md - - Offline Inference: examples/offline_inference - - Online Serving: 
examples/online_serving - - Others: examples/others + - Examples: examples - General: - usage/v1_guide.md - usage/* @@ -52,12 +48,18 @@ nav: - Plugins: - design/*plugin*.md - design/* + - Benchmarking: + - benchmarking/README.md + - benchmarking/cli.md + - benchmarking/sweeps.md + - benchmarking/dashboard.md - API Reference: - api/README.md - api/vllm - CLI Reference: cli - Community: - community/* + - Governance: governance - Blog: https://blog.vllm.ai - Forum: https://discuss.vllm.ai - Slack: https://slack.vllm.ai diff --git a/docs/api/README.md b/docs/api/README.md index d3a141f32730..d51329ec2faa 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes. - [vllm.config.MultiModalConfig][] - [vllm.config.PoolerConfig][] - [vllm.config.StructuredOutputsConfig][] +- [vllm.config.ProfilerConfig][] - [vllm.config.ObservabilityConfig][] - [vllm.config.KVTransferConfig][] - [vllm.config.CompilationConfig][] diff --git a/docs/benchmarking/README.md b/docs/benchmarking/README.md new file mode 100644 index 000000000000..238290d4762b --- /dev/null +++ b/docs/benchmarking/README.md @@ -0,0 +1,7 @@ +# Benchmark Suites + +vLLM provides comprehensive benchmarking tools for performance testing and evaluation: + +- **[Benchmark CLI](./cli.md)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing. +- **[Parameter Sweeps](./sweeps.md)**: Automate `vllm bench` runs for multiple configurations, useful for [optimization and tuning](../configuration/optimization.md). +- **[Performance Dashboard](./dashboard.md)**: Automated CI that publishes benchmarks on each commit. diff --git a/docs/contributing/benchmarks.md b/docs/benchmarking/cli.md similarity index 72% rename from docs/contributing/benchmarks.md rename to docs/benchmarking/cli.md index c9bc9cfe28a3..dd5a12e408b0 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/benchmarking/cli.md @@ -1,22 +1,10 @@ ---- -toc_depth: 4 ---- +# Benchmark CLI -# Benchmark Suites +This section guides you through running benchmark tests with the extensive datasets supported on vLLM. -vLLM provides comprehensive benchmarking tools for performance testing and evaluation: +It's a living document, updated as new features and datasets become available. -- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing -- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations -- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development - -## Benchmark CLI - -This section guides you through running benchmark tests with the extensive -datasets supported on vLLM. It's a living document, updated as new features and datasets -become available. - -### Dataset Overview +## Dataset Overview
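+As a quick, illustrative sketch (the model name and flag values below are placeholders, and a vLLM server is assumed to already be running locally), a dataset is selected with the `--dataset-name` flag of `vllm bench serve`:
+
+```bash
+vllm bench serve \
+  --model NousResearch/Hermes-3-Llama-3.1-8B \
+  --dataset-name random \
+  --random-input-len 512 \
+  --random-output-len 128 \
+  --num-prompts 100
+```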