diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index bc6cf83bc21f..1aadb667e5b4 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import argparse
+import inspect
 import json
 import os
 import time
@@ -14,9 +15,9 @@
 import torch
 from ray.experimental.tqdm_ray import tqdm
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
-    _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
@@ -24,6 +25,8 @@
 from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+logger = init_logger(__name__)
+
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
@@ -145,20 +148,15 @@ def run():
         else:
             quant_dtype = None
 
-        quant_config = FusedMoEQuantConfig.make(
-            quant_dtype=quant_dtype,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-            block_shape=block_quant_shape,
+        quant_config = make_quant_config_compatible(
+            quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
         )
 
         with override_config(config):
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 x, input_gating, topk, renormalize=not use_deep_gemm
             )
-            return fused_experts(
+            return fused_experts_compatible(
                 x,
                 w1,
                 w2,
@@ -411,7 +409,7 @@ def benchmark(
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
-        dtype_str = _get_config_dtype_str(
+        dtype_str = _get_config_dtype_str_compatible(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -544,7 +542,7 @@ def save_configs(
     block_quant_shape: list[int],
     save_dir: str,
 ) -> None:
-    dtype_str = _get_config_dtype_str(
+    dtype_str = _get_config_dtype_str_compatible(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
     )
 
@@ -568,6 +566,116 @@ def get_weight_block_size_safety(config, default_value=None):
     return default_value
 
 
+def _get_config_dtype_str_compatible(
+    dtype: torch.dtype,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    ocp_mx_scheme: str | None = None,
+) -> str | None:
+    """
+    Multi-level import fallback for _get_config_dtype_str function.
+    Returns a string used to construct the filename for tuning info.
+    Uses dynamic signature inspection to only pass supported parameters.
+    """
+    try:
+        from vllm.model_executor.layers.fused_moe.config import (
+            _get_config_dtype_str as _original_func,
+        )
+
+        sig = inspect.signature(_original_func)
+        kwargs = {}
+        if "use_fp8_w8a8" in sig.parameters:
+            kwargs["use_fp8_w8a8"] = use_fp8_w8a8
+        if "use_int8_w8a16" in sig.parameters:
+            kwargs["use_int8_w8a16"] = use_int8_w8a16
+        if "use_int4_w4a16" in sig.parameters:
+            kwargs["use_int4_w4a16"] = use_int4_w4a16
+        if "ocp_mx_scheme" in sig.parameters:
+            kwargs["ocp_mx_scheme"] = ocp_mx_scheme
+
+        return _original_func(dtype, **kwargs)
+    except ImportError:
+        # Fallback implementation that mimics the original function's logic
+        if use_fp8_w8a8:
+            return "fp8_w8a8"
+        elif use_int8_w8a16:
+            return "int8_w8a16"
+        elif use_int4_w4a16:
+            return "int4_w4a16"
+        elif ocp_mx_scheme is not None:
+            # For OCP MX execution simulation
+            return None
+        elif dtype == torch.float:
+            # avoiding cases where kernel fails when float32 MoE
+            # use fp16/bfloat16 configs
+            return "float32"
+        return None
+
+
+def make_quant_config_compatible(
+    quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
+):
+    """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
+    if quant_dtype is None:
+        return None
+    param_combinations = [
+        {
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+            "block_shape": block_quant_shape,
+        },
+        {
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+        },
+        {
+            "dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+        },
+    ]
+    for params in param_combinations:
+        filtered_params = {k: v for k, v in params.items() if v is not None}
+        try:
+            return FusedMoEQuantConfig.make(**filtered_params)
+        except TypeError:
+            continue
+    raise TypeError(
+        "Unable to create FusedMoEQuantConfig with any known parameter combination."
+    )
+
+
+def fused_experts_compatible(
+    x,
+    w1,
+    w2,
+    topk_weights,
+    topk_ids,
+    inplace=True,
+    quant_config=None,
+    allow_deep_gemm=False,
+):
+    """Compatible wrapper for fused_experts function."""
+    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+
+    sig = inspect.signature(fused_experts)
+    kwargs = {"inplace": inplace}
+    if "quant_config" in sig.parameters:
+        kwargs["quant_config"] = quant_config
+    if "allow_deep_gemm" in sig.parameters:
+        kwargs["allow_deep_gemm"] = allow_deep_gemm
+    return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs)
+
+
 def main(args: argparse.Namespace):
     print(args)
 
@@ -665,7 +773,7 @@ def main(args: argparse.Namespace):
     if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
         # Ray will set ROCR_VISIBLE_DEVICES for device visibility
         logger.warning(
-            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
+            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
             "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
         )
         val = os.environ["HIP_VISIBLE_DEVICES"]