@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse
+import inspect
 import json
 import os
 import time
@@ -16,7 +17,6 @@

 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
-    _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
@@ -145,20 +145,15 @@ def run():
         else:
             quant_dtype = None

-        quant_config = FusedMoEQuantConfig.make(
-            quant_dtype=quant_dtype,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-            block_shape=block_quant_shape,
+        quant_config = make_quant_config_compatible(
+            quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
         )

         with override_config(config):
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 x, input_gating, topk, renormalize=not use_deep_gemm
             )
-            return fused_experts(
+            return fused_experts_compatible(
                 x,
                 w1,
                 w2,
@@ -411,7 +406,7 @@ def benchmark(
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
-        dtype_str = _get_config_dtype_str(
+        dtype_str = _get_config_dtype_str_compatible(
             dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -544,7 +539,7 @@ def save_configs(
     block_quant_shape: list[int],
     save_dir: str,
 ) -> None:
-    dtype_str = _get_config_dtype_str(
+    dtype_str = _get_config_dtype_str_compatible(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
     )

@@ -568,6 +563,126 @@ def get_weight_block_size_safety(config, default_value=None):
     return default_value


+def _get_config_dtype_str_compatible(
+    dtype: torch.dtype,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    ocp_mx_scheme: str | None = None,
+) -> str | None:
+    """
+    Multi-level import fallback for the _get_config_dtype_str function.
+    Returns a string used to construct the filename for tuning info.
+    Uses dynamic signature inspection to pass only supported parameters.
+    """
+    try:
+        from vllm.model_executor.layers.fused_moe.config import (
+            _get_config_dtype_str as _original_func,
+        )
+
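+        # Newer vLLM releases add keyword parameters (e.g. ocp_mx_scheme);
+        # passing them to an older signature raises TypeError, so forward
+        # only the keywords the installed version actually declares.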
+        sig = inspect.signature(_original_func)
+        kwargs = {}
+        if "use_fp8_w8a8" in sig.parameters:
+            kwargs["use_fp8_w8a8"] = use_fp8_w8a8
+        if "use_int8_w8a16" in sig.parameters:
+            kwargs["use_int8_w8a16"] = use_int8_w8a16
+        if "use_int4_w4a16" in sig.parameters:
+            kwargs["use_int4_w4a16"] = use_int4_w4a16
+        if "ocp_mx_scheme" in sig.parameters:
+            kwargs["ocp_mx_scheme"] = ocp_mx_scheme
+
+        return _original_func(dtype, **kwargs)
+    except ImportError:
+        # Fallback implementation that mimics the original function's logic
+        if use_fp8_w8a8:
+            return "fp8_w8a8"
+        elif use_int8_w8a16:
+            return "int8_w8a16"
+        elif use_int4_w4a16:
+            return "int4_w4a16"
+        elif ocp_mx_scheme is not None:
+            # For OCP MX execution simulation
+            return None
+        elif dtype == torch.float:
+            # Avoid cases where the kernel fails when float32 MoE
+            # uses fp16/bfloat16 configs
+            return "float32"
+        return None
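+    # e.g. use_fp8_w8a8=True yields "fp8_w8a8", while a plain fp16/bf16
+    # dtype yields None and the default config filename is used.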
+
+
+def make_quant_config_compatible(
+    quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
+):
+    """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
+    if quant_dtype is None:
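+        # Unquantized run: downstream code accepts quant_config=None.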
+        return None
+    param_combinations = [
+        {
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+            "block_shape": block_quant_shape,
+        },
+        {
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+        },
+        {
+            "dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+        },
+    ]
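+    # Candidate signatures are tried in order, from the newest form (with
+    # block_shape) down to the variant that uses the "dtype" keyword.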
+    for params in param_combinations:
+        filtered_params = {k: v for k, v in params.items() if v is not None}
+        try:
+            return FusedMoEQuantConfig.make(**filtered_params)
+        except TypeError:
+            continue
+    raise TypeError(
+        "Unable to create FusedMoEQuantConfig with any known parameter combination."
+    )
+
+
+def fused_experts_compatible(
+    x,
+    w1,
+    w2,
+    topk_weights,
+    topk_ids,
+    inplace=True,
+    quant_config=None,
+    allow_deep_gemm=False,
+):
+    """Compatible wrapper for the fused_experts function."""
+    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+
+    sig = inspect.signature(fused_experts)
+    kwargs = {"inplace": inplace}
+    if "quant_config" in sig.parameters:
+        kwargs["quant_config"] = quant_config
+    if "allow_deep_gemm" in sig.parameters:
+        kwargs["allow_deep_gemm"] = allow_deep_gemm
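+    # Older fused_experts signatures lack these kwargs; in that case they
+    # are simply not forwarded.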
+    return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs)
+
+
 def main(args: argparse.Namespace):
     print(args)

@@ -664,8 +779,8 @@ def main(args: argparse.Namespace):

     if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
         # Ray will set ROCR_VISIBLE_DEVICES for device visibility
-        logger.warning(
-            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
+        print(
+            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
             "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
         )
         val = os.environ["HIP_VISIBLE_DEVICES"]