132 changes: 120 additions & 12 deletions benchmarks/kernels/benchmark_moe.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import inspect
import json
import os
import time
@@ -14,16 +15,18 @@
import torch
from ray.experimental.tqdm_ray import tqdm

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()


@@ -145,20 +148,15 @@ def run():
else:
quant_dtype = None

quant_config = FusedMoEQuantConfig.make(
quant_dtype=quant_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
quant_config = make_quant_config_compatible(
quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
)

with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
return fused_experts(
return fused_experts_compatible(
x,
w1,
w2,
@@ -411,7 +409,7 @@ def benchmark(
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = _get_config_dtype_str(
dtype_str = _get_config_dtype_str_compatible(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -544,7 +542,7 @@ def save_configs(
block_quant_shape: list[int],
save_dir: str,
) -> None:
dtype_str = _get_config_dtype_str(
dtype_str = _get_config_dtype_str_compatible(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)

@@ -568,6 +566,116 @@ def get_weight_block_size_safety(config, default_value=None):
return default_value


def _get_config_dtype_str_compatible(
dtype: torch.dtype,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
) -> str | None:
"""
Version-compatible wrapper for _get_config_dtype_str.
Returns a string used to construct the filename for tuning info.
Uses dynamic signature inspection to pass only the parameters the installed
version supports, and falls back to a local reimplementation on ImportError.
"""
try:
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str as _original_func,
)

sig = inspect.signature(_original_func)
kwargs = {}
if "use_fp8_w8a8" in sig.parameters:
kwargs["use_fp8_w8a8"] = use_fp8_w8a8
if "use_int8_w8a16" in sig.parameters:
kwargs["use_int8_w8a16"] = use_int8_w8a16
if "use_int4_w4a16" in sig.parameters:
kwargs["use_int4_w4a16"] = use_int4_w4a16
if "ocp_mx_scheme" in sig.parameters:
kwargs["ocp_mx_scheme"] = ocp_mx_scheme

return _original_func(dtype, **kwargs)
except ImportError:
# Fallback implementation that mimics the original function's logic
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif ocp_mx_scheme is not None:
# For OCP MX execution simulation
return None
elif dtype == torch.float:
# avoid cases where the kernel fails when a float32 MoE
# uses fp16/bfloat16 configs
return "float32"
return None


def make_quant_config_compatible(
quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
):
"""Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
if quant_dtype is None:
return None
param_combinations = [
{
"quant_dtype": quant_dtype,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"a1_scale": a1_scale,
"a2_scale": a2_scale,
"block_shape": block_quant_shape,
},
{
"quant_dtype": quant_dtype,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"a1_scale": a1_scale,
"a2_scale": a2_scale,
},
{
"dtype": quant_dtype,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"a1_scale": a1_scale,
"a2_scale": a2_scale,
},
]
for params in param_combinations:
filtered_params = {k: v for k, v in params.items() if v is not None}
try:
return FusedMoEQuantConfig.make(**filtered_params)
except TypeError:
continue
raise TypeError(
"Unable to create FusedMoEQuantConfig with any known parameter combination."
)


def fused_experts_compatible(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
quant_config=None,
allow_deep_gemm=False,
):
"""Compatible wrapper for fused_experts function."""
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

sig = inspect.signature(fused_experts)
kwargs = {"inplace": inplace}
if "quant_config" in sig.parameters:
kwargs["quant_config"] = quant_config
if "allow_deep_gemm" in sig.parameters:
kwargs["allow_deep_gemm"] = allow_deep_gemm
return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs)


def main(args: argparse.Namespace):
print(args)

@@ -665,7 +773,7 @@ def main(args: argparse.Namespace):
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val = os.environ["HIP_VISIBLE_DEVICES"]
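
The three compatibility helpers added in this diff share one idea: inspect the callee's signature (or catch TypeError) and forward only the arguments the installed vLLM version accepts. Below is a minimal, standalone sketch of that pattern; make_config_v1 and make_config_v2 are hypothetical stand-ins used only for illustration, not real vLLM APIs.

import inspect


def call_with_supported_kwargs(func, *args, **kwargs):
    # Forward only the keyword arguments that func actually declares.
    sig = inspect.signature(func)
    supported = {k: v for k, v in kwargs.items() if k in sig.parameters}
    return func(*args, **supported)


# Hypothetical "old" and "new" API variants, for illustration only.
def make_config_v1(dtype, w1_scale=None, w2_scale=None):
    return {"dtype": dtype, "w1_scale": w1_scale, "w2_scale": w2_scale}


def make_config_v2(quant_dtype, w1_scale=None, w2_scale=None, block_shape=None):
    return {"quant_dtype": quant_dtype, "block_shape": block_shape}


def make_config_compatible(quant_dtype, **kwargs):
    # Try the newer parameter spelling first, then fall back on TypeError,
    # mirroring the strategy of make_quant_config_compatible above.
    for func, key in ((make_config_v2, "quant_dtype"), (make_config_v1, "dtype")):
        try:
            return call_with_supported_kwargs(func, **{key: quant_dtype}, **kwargs)
        except TypeError:
            continue
    raise TypeError("No compatible make_config variant found.")


if __name__ == "__main__":
    print(make_config_compatible("fp8", w1_scale=1.0, block_shape=[128, 128]))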