
Commit 95d680b

[Bugfix][IPEX] Add VLLM_CPU_MOE_PREPACK to allow disabling MoE prepack when CPU does not support it (#14681)
Signed-off-by: Thien Tran <[email protected]>
1 parent fb4c7f8 commit 95d680b

File tree

3 files changed: +10 -1 lines changed


docs/source/getting_started/installation/cpu.md (+1)
```diff
@@ -195,6 +195,7 @@ vLLM CPU backend supports the following vLLM features:
 
 - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
 - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.
+- `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
 
 ## Performance tips
```
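For context, a minimal sketch of how a user could disable the new option from Python on a CPU without the required support; the model name is illustrative only and not part of this commit:

```python
# Sketch: run vLLM's CPU backend with MoE prepack disabled.
# Assumption: the variable is set before vLLM reads its environment.
import os

os.environ["VLLM_CPU_MOE_PREPACK"] = "0"  # "0" disables prepack; default is "1"

from vllm import LLM

# Hypothetical MoE model choice; any MoE model routed through
# ipex.llm.modules.GatedMLPMOE would exercise this code path.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
print(llm.generate("Hello!"))
```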

vllm/envs.py (+7)
```diff
@@ -40,6 +40,7 @@
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
+    VLLM_CPU_MOE_PREPACK: bool = True
     VLLM_OPENVINO_DEVICE: str = "CPU"
     VLLM_OPENVINO_KVCACHE_SPACE: int = 0
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
@@ -349,6 +350,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_CPU_OMP_THREADS_BIND":
     lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
 
+    # (CPU backend only) whether to use prepack for MoE layer. This will be
+    # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
+    # need to set this to "0" (False).
+    "VLLM_CPU_MOE_PREPACK":
+    lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
+
     # OpenVINO device selection
     # default is CPU
     "VLLM_OPENVINO_DEVICE":
```

vllm/model_executor/layers/fused_moe/layer.py (+2 -1)
```diff
@@ -7,6 +7,7 @@
 import torch
 from torch.nn.parameter import UninitializedParameter
 
+from vllm import envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -104,7 +105,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                 layer.w13_weight,
                 layer.w2_weight,
-                use_prepack=True,
+                use_prepack=envs.VLLM_CPU_MOE_PREPACK,
             )
         else:
             raise NotImplementedError("CPU MOE only supports x86 arch.")
```
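The layer reads `envs.VLLM_CPU_MOE_PREPACK` at attribute-access time: vLLM's `envs` module resolves these names lazily through a module-level `__getattr__`. A simplified sketch of that pattern (an approximation for illustration, not the actual `vllm/envs.py` source):

```python
# Simplified sketch of the lazy env-lookup pattern used by vllm/envs.py.
import os
from typing import Any, Callable

# Each name maps to a zero-argument callable evaluated on every access,
# so the environment is consulted at use time rather than at import time.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_CPU_MOE_PREPACK":
    lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
}

def __getattr__(name: str) -> Any:  # PEP 562 module-level __getattr__
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

Because the lookup happens on access, `use_prepack=envs.VLLM_CPU_MOE_PREPACK` picks up whatever value the variable holds when `process_weights_after_loading` runs.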
