Skip to content

[Bugfix][IPEX] Add VLLM_CPU_MOE_PREPACK to allow disabling MoE prepack when CPU does not support it #14681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/getting_started/installation/cpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ vLLM CPU backend supports the following vLLM features:

- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.
- `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).

## Performance tips

Expand Down
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_MOE_PREPACK: bool = True
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
Expand Down Expand Up @@ -349,6 +350,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),

# (CPU backend only) whether to use prepack for MoE layer. This will be
# passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
# need to set this to "0" (False).
"VLLM_CPU_MOE_PREPACK":
lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),

# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from torch.nn.parameter import UninitializedParameter

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -104,7 +105,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
)
else:
raise NotImplementedError("CPU MOE only supports x86 arch.")
Expand Down