
Conversation


@massif-01 massif-01 commented Oct 27, 2025

This PR introduces compatibility fixes to benchmarks/kernels/benchmark_moe.py so that it runs across multiple vLLM versions without runtime import or parameter errors. I hit these errors while tuning Qwen3-Coder-30B-A3B-Instruct-FP8:

1. ImportError: cannot import name '_get_config_dtype_str'

   Added a multi-level import fallback that searches the possible module locations and class methods for _get_config_dtype_str and provides a fallback implementation when it is unavailable.

2. TypeError: FusedMoEQuantConfig.make() parameter incompatibility

   Implemented make_quant_config_compatible(), which tries multiple parameter combinations (quant_dtype vs. dtype, with and without block_quant_shape) to create a FusedMoEQuantConfig across versions; see the first sketch after this list.

3. TypeError: fused_experts() parameter incompatibility

   Implemented fused_experts_compatible(), which inspects the fused_experts signature and passes only the parameters it supports (quant_config, allow_deep_gemm, etc.); see the second sketch after this list.
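
For reference, a minimal sketch of the parameter-probing idea behind make_quant_config_compatible(). The keyword combinations tried below are assumptions based only on the names mentioned in this description (quant_dtype vs. dtype, block_shape vs. block_quant_shape); the wrapper in the PR may try more or different ones:

def make_quant_config_compatible(dtype, block_quant_shape=None):
    """Construct a FusedMoEQuantConfig across vLLM versions by trying
    several keyword spellings until one is accepted."""
    from vllm.model_executor.layers.fused_moe.config import (
        FusedMoEQuantConfig,
    )

    # Keyword sets to try, newest spelling first. Which names exist
    # depends on the installed vLLM version.
    candidates = [
        {"quant_dtype": dtype, "block_shape": block_quant_shape},
        {"quant_dtype": dtype},
        {"dtype": dtype, "block_quant_shape": block_quant_shape},
        {"dtype": dtype},
    ]
    for kwargs in candidates:
        try:
            return FusedMoEQuantConfig.make(**kwargs)
        except TypeError:
            continue
    raise TypeError(
        "No compatible FusedMoEQuantConfig.make() signature found"
    )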
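
Similarly, a minimal sketch of the signature-inspection approach behind fused_experts_compatible(), assuming the usual fused_experts import path; only the version-dependent keywords named above are probed here, and the real wrapper may handle more:

import inspect

from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts


def fused_experts_compatible(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    inplace=True,
    quant_config=None,
    allow_deep_gemm=False,
):
    """Call fused_experts, forwarding only the keyword arguments that the
    installed vLLM version actually accepts."""
    sig = inspect.signature(fused_experts)
    kwargs = {"inplace": inplace}
    if "quant_config" in sig.parameters:
        kwargs["quant_config"] = quant_config
    if "allow_deep_gemm" in sig.parameters:
        kwargs["allow_deep_gemm"] = allow_deep_gemm
    return fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, **kwargs)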
Signed-off-by: massif-01 [email protected]

@mergify mergify bot added the performance Performance-related issues label Oct 27, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several compatibility wrappers to benchmark_moe.py to handle API differences across vLLM versions, which is a solid approach. The changes are well-motivated and address specific ImportError and TypeError issues encountered during benchmarking. My main feedback is to improve the robustness of one of the new wrapper functions to avoid masking potential bugs, making it consistent with other wrappers in this PR.

Comment on lines 566 to 613
def _get_config_dtype_str_compatible(
    dtype: torch.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
) -> str | None:
    """
    Multi-level import fallback for _get_config_dtype_str function.
    Returns a string used to construct the filename for tuning info.
    """
    try:
        from vllm.model_executor.layers.fused_moe.config import (
            _get_config_dtype_str as _original_func,
        )

        return _original_func(
            dtype,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            ocp_mx_scheme=ocp_mx_scheme,
        )
    except (ImportError, TypeError):
        # Fallback implementation that mimics the original function's logic
        if use_fp8_w8a8:
            return "fp8_w8a8"
        elif use_int8_w8a16:
            return "int8_w8a16"
        elif use_int4_w4a16:
            return "int4_w4a16"
        elif ocp_mx_scheme is not None:
            # For OCP MX execution simulation
            return None
        elif dtype == torch.float:
            # avoiding cases where kernel fails when float32 MoE
            # use fp16/bfloat16 configs
            return "float32"
        return None
Contributor

high

Catching TypeError here is too broad and can mask other potential bugs. For instance, if _original_func is called with an incorrect type for one of its arguments, the TypeError would be silently caught, and the code would fall back to a potentially incorrect logic. A more robust approach, consistent with fused_experts_compatible in this same PR, is to use inspect.signature to dynamically build the keyword arguments and only catch ImportError. This ensures that unexpected TypeErrors are surfaced for debugging.

def _get_config_dtype_str_compatible(
    dtype: torch.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
) -> str | None:
    """
    Multi-level import fallback for _get_config_dtype_str function.
    Returns a string used to construct the filename for tuning info.
    """
    try:
        from vllm.model_executor.layers.fused_moe.config import (
            _get_config_dtype_str as _original_func,
        )

        sig = inspect.signature(_original_func)
        kwargs = {}
        if "use_fp8_w8a8" in sig.parameters:
            kwargs["use_fp8_w8a8"] = use_fp8_w8a8
        if "use_int8_w8a16" in sig.parameters:
            kwargs["use_int8_w8a16"] = use_int8_w8a16
        if "use_int4_w4a16" in sig.parameters:
            kwargs["use_int4_w4a16"] = use_int4_w4a16
        if "ocp_mx_scheme" in sig.parameters:
            kwargs["ocp_mx_scheme"] = ocp_mx_scheme

        return _original_func(dtype, **kwargs)
    except ImportError:
        # Fallback implementation that mimics the original function's logic
        if use_fp8_w8a8:
            return "fp8_w8a8"
        elif use_int8_w8a16:
            return "int8_w8a16"
        elif use_int4_w4a16:
            return "int4_w4a16"
        elif ocp_mx_scheme is not None:
            # For OCP MX execution simulation
            return None
        elif dtype == torch.float:
            # avoiding cases where kernel fails when float32 MoE
            # use fp16/bfloat16 configs
            return "float32"
        return None
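
(For illustration only: the wrapper above is meant as a drop-in replacement for the original helper at its call sites in the benchmark, e.g. something like the following, where the argument values are assumptions.)

config_dtype = _get_config_dtype_str_compatible(
    torch.bfloat16,
    use_fp8_w8a8=True,  # e.g. when tuning an FP8 W8A8 MoE kernel
)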

Author

I've improved the code.

@massif-01 massif-01 force-pushed the fix/benchmark-moe-compatibility branch from 2527b65 to c27760b on October 27, 2025 01:43
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
    "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
print(
Contributor

please use logger here

Author

I've revised it. Please review.
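
For context, a minimal sketch of the kind of change requested above, assuming the module-level logger that benchmark_moe.py already uses; the variable name and message text are illustrative, not the actual revision:

# Before: plain stdout output (flagged in review)
print(f"Start tuning over {len(search_space)} configurations...")

# After: route the same message through the existing logger
logger.info("Start tuning over %d configurations...", len(search_space))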

@massif-01 massif-01 force-pushed the fix/benchmark-moe-compatibility branch from c27760b to 2b52444 on October 27, 2025 03:15
@massif-01 massif-01 requested a review from ZJY0516 October 27, 2025 03:28
This PR fixes compatibility issues in benchmarks/kernels/benchmark_moe.py
to support multiple vLLM versions:

1. Fix FusedMoEQuantConfig parameter naming
   - Use correct 'block_shape' parameter name instead of 'block_quant_shape'
   - Prevents TypeError: unexpected keyword argument

2. Add compatible wrapper for _get_config_dtype_str
   - Uses inspect.signature for dynamic parameter detection
   - Only catches ImportError (not TypeError) to surface bugs properly
   - Robust fallback implementation

3. Add compatible wrapper for fused_experts
   - Dynamic signature inspection to pass only supported parameters
   - Handles quant_config and allow_deep_gemm parameters gracefully

These fixes resolve runtime errors that occurred when running MoE
benchmarks with different vLLM versions, particularly when using
block quantization.

Tested with:
- Python syntax: passed
- Ruff linter: all checks passed
- Line length: all lines < 88 chars

Addresses code review feedback from gemini-code-assist bot.

Signed-off-by: Alfred <[email protected]>
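
For reference, a minimal sketch of the keyword fix described in item 1 of the commit message above. The quant_dtype keyword and the 128x128 block size are assumptions for illustration; the point is only the block_shape spelling:

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

quant_config = FusedMoEQuantConfig.make(
    quant_dtype=torch.float8_e4m3fn,  # assumption: FP8 weights/activations
    block_shape=[128, 128],  # correct keyword; 'block_quant_shape' raises TypeError
)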
@massif-01 massif-01 force-pushed the fix/benchmark-moe-compatibility branch from 2b52444 to b3d1e82 on October 27, 2025 04:08

ZJY0516 commented Oct 27, 2025

I am not sure we really need this compatibility.
