22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import argparse
5+ import inspect
56import json
67import os
78import time
1415import torch
1516from ray .experimental .tqdm_ray import tqdm
1617
18+ from vllm .logger import init_logger
1719from vllm .model_executor .layers .fused_moe .config import (
1820 FusedMoEQuantConfig ,
19- _get_config_dtype_str ,
2021)
2122from vllm .model_executor .layers .fused_moe .fused_moe import *
2223from vllm .platforms import current_platform
2324from vllm .transformers_utils .config import get_config
2425from vllm .triton_utils import triton
2526from vllm .utils .argparse_utils import FlexibleArgumentParser
2627
28+ logger = init_logger (__name__ )
29+
2730FP8_DTYPE = current_platform .fp8_dtype ()
2831
2932
@@ -145,20 +148,15 @@ def run():
145148 else :
146149 quant_dtype = None
147150
148- quant_config = FusedMoEQuantConfig .make (
149- quant_dtype = quant_dtype ,
150- w1_scale = w1_scale ,
151- w2_scale = w2_scale ,
152- a1_scale = a1_scale ,
153- a2_scale = a2_scale ,
154- block_shape = block_quant_shape ,
151+ quant_config = make_quant_config_compatible (
152+ quant_dtype , w1_scale , w2_scale , a1_scale , a2_scale , block_quant_shape
155153 )
156154
157155 with override_config (config ):
158156 topk_weights , topk_ids , token_expert_indices = fused_topk (
159157 x , input_gating , topk , renormalize = not use_deep_gemm
160158 )
161- return fused_experts (
159+ return fused_experts_compatible (
162160 x ,
163161 w1 ,
164162 w2 ,
@@ -411,7 +409,7 @@ def benchmark(
411409 use_deep_gemm : bool = False ,
412410 ) -> tuple [dict [str , int ], float ]:
413411 current_platform .seed_everything (self .seed )
414- dtype_str = _get_config_dtype_str (
412+ dtype_str = _get_config_dtype_str_compatible (
415413 dtype , use_int8_w8a16 = use_int8_w8a16 , use_fp8_w8a8 = use_fp8_w8a8
416414 )
417415 # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -544,7 +542,7 @@ def save_configs(
544542 block_quant_shape : list [int ],
545543 save_dir : str ,
546544) -> None :
547- dtype_str = _get_config_dtype_str (
545+ dtype_str = _get_config_dtype_str_compatible (
548546 dtype , use_int8_w8a16 = use_int8_w8a16 , use_fp8_w8a8 = use_fp8_w8a8
549547 )
550548
@@ -568,6 +566,116 @@ def get_weight_block_size_safety(config, default_value=None):
568566 return default_value
569567
570568
569+ def _get_config_dtype_str_compatible (
570+ dtype : torch .dtype ,
571+ use_fp8_w8a8 : bool = False ,
572+ use_int8_w8a16 : bool = False ,
573+ use_int4_w4a16 : bool = False ,
574+ ocp_mx_scheme : str | None = None ,
575+ ) -> str | None :
576+ """
577+ Multi-level import fallback for _get_config_dtype_str function.
578+ Returns a string used to construct the filename for tuning info.
579+ Uses dynamic signature inspection to only pass supported parameters.
580+ """
581+ try :
582+ from vllm .model_executor .layers .fused_moe .config import (
583+ _get_config_dtype_str as _original_func ,
584+ )
585+
586+ sig = inspect .signature (_original_func )
587+ kwargs = {}
588+ if "use_fp8_w8a8" in sig .parameters :
589+ kwargs ["use_fp8_w8a8" ] = use_fp8_w8a8
590+ if "use_int8_w8a16" in sig .parameters :
591+ kwargs ["use_int8_w8a16" ] = use_int8_w8a16
592+ if "use_int4_w4a16" in sig .parameters :
593+ kwargs ["use_int4_w4a16" ] = use_int4_w4a16
594+ if "ocp_mx_scheme" in sig .parameters :
595+ kwargs ["ocp_mx_scheme" ] = ocp_mx_scheme
596+
597+ return _original_func (dtype , ** kwargs )
598+ except ImportError :
599+ # Fallback implementation that mimics the original function's logic
600+ if use_fp8_w8a8 :
601+ return "fp8_w8a8"
602+ elif use_int8_w8a16 :
603+ return "int8_w8a16"
604+ elif use_int4_w4a16 :
605+ return "int4_w4a16"
606+ elif ocp_mx_scheme is not None :
607+ # For OCP MX execution simulation
608+ return None
609+ elif dtype == torch .float :
610+ # avoiding cases where kernel fails when float32 MoE
611+ # use fp16/bfloat16 configs
612+ return "float32"
613+ return None
614+
615+
def make_quant_config_compatible(
    quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
):
    """Build a ``FusedMoEQuantConfig`` across differing ``make()`` signatures.

    Several keyword layouts are tried in order; the first one that
    ``FusedMoEQuantConfig.make`` accepts without raising ``TypeError`` wins.
    Arguments whose value is ``None`` are omitted from every attempt.

    Args:
        quant_dtype: Quantization dtype; ``None`` means no quant config.
        w1_scale: Scale for the first expert weight, or ``None``.
        w2_scale: Scale for the second expert weight, or ``None``.
        a1_scale: Scale for the first activation, or ``None``.
        a2_scale: Scale for the second activation, or ``None``.
        block_quant_shape: Block quantization shape, or ``None``.

    Returns:
        The value returned by ``FusedMoEQuantConfig.make``, or ``None`` when
        ``quant_dtype`` is ``None``.

    Raises:
        TypeError: If every keyword layout is rejected. The last underlying
            ``TypeError`` is attached as ``__cause__``.
    """
    if quant_dtype is None:
        return None

    scales = {
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    # Ordered from the most complete keyword layout to the least.
    candidates = [
        {"quant_dtype": quant_dtype, **scales, "block_shape": block_quant_shape},
        {"quant_dtype": quant_dtype, **scales},
        {"dtype": quant_dtype, **scales},
    ]

    last_error = None
    for candidate in candidates:
        call_kwargs = {key: val for key, val in candidate.items() if val is not None}
        try:
            return FusedMoEQuantConfig.make(**call_kwargs)
        except TypeError as err:
            # Remember why this layout was rejected so the final error can
            # report it instead of discarding the diagnostic.
            last_error = err

    raise TypeError(
        "Unable to create FusedMoEQuantConfig with any known parameter combination."
    ) from last_error
655+
656+
def fused_experts_compatible(
    x,
    w1,
    w2,
    topk_weights,
    topk_ids,
    inplace=True,
    quant_config=None,
    allow_deep_gemm=False,
):
    """Call ``fused_experts`` with only the optional keywords it declares.

    ``inplace`` is always forwarded. ``quant_config`` and ``allow_deep_gemm``
    are forwarded only when the imported ``fused_experts`` lists a parameter
    of that name in its signature.

    Returns:
        Whatever ``fused_experts`` returns.
    """
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

    declared = inspect.signature(fused_experts).parameters
    call_kwargs = {"inplace": inplace}
    version_dependent = (
        ("quant_config", quant_config),
        ("allow_deep_gemm", allow_deep_gemm),
    )
    for name, value in version_dependent:
        if name in declared:
            call_kwargs[name] = value
    return fused_experts(x, w1, w2, topk_weights, topk_ids, **call_kwargs)
677+
678+
571679def main (args : argparse .Namespace ):
572680 print (args )
573681
@@ -665,7 +773,7 @@ def main(args: argparse.Namespace):
665773 if current_platform .is_rocm () and "HIP_VISIBLE_DEVICES" in os .environ :
666774 # Ray will set ROCR_VISIBLE_DEVICES for device visibility
667775 logger .warning (
668- "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
776+ "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
669777 "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
670778 )
671779 val = os .environ ["HIP_VISIBLE_DEVICES" ]
0 commit comments