diff --git a/python/pyproject.toml b/python/pyproject.toml index 11c984f82d7..e0c0adc1c5c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,16 +19,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] runtime_common = [ "aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "modelscope", - "orjson", "outlines>=0.0.44,<0.1.0", - "packaging", "pillow", "prometheus-client>=0.20.0", - "psutil", "pydantic", "python-multipart", - "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar>=0.1.10" + "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", + "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", + "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar==0.1.10", "ninja", "transformers==4.48.3" ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", - "flashinfer==0.1.6" + "sgl-kernel>=0.0.3.post6", "torch", "vllm>=0.6.4.post1,<=0.7.2", + "flashinfer_python>=0.2.1.post2", + "outlines>=0.0.44,<=0.1.11", ] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 32c8fcbb625..aca1cbce238 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -8,6 +8,7 @@ import os from typing import Any, Callable, Dict, List, Optional, Tuple +import orjson import torch import triton import triton.language as tl @@ -82,6 +83,7 @@ def fused_moe_kernel( compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + use_int8_w8a8: tl.constexpr, even_Ks: tl.constexpr, ): """ @@ -104,6 +106,7 @@ def fused_moe_kernel( - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by @@ -165,6 +168,16 @@ def fused_moe_kernel( a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) + if use_int8_w8a8: + # Load per-column scale for weights + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0) + # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -221,6 +234,8 @@ def fused_moe_kernel( accumulator = accumulator.to(compute_type) else: accumulator = (accumulator * a_scale * b_scale).to(compute_type) + elif use_int8_w8a8: + accumulator = (accumulator * a_scale[:, None] * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -473,6 +488,7 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int8_w8a8: bool, block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 @@ -493,6 +509,8 @@ def invoke_fused_moe_kernel( assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None + elif use_int8_w8a8: + A, A_scale = per_token_quant_int8(A) else: assert A_scale is None assert B_scale is None @@ -507,7 +525,6 @@ def invoke_fused_moe_kernel( even_Ks = True else: even_Ks = False - fused_moe_kernel[grid]( A, B, @@ -541,6 +558,7 @@ def invoke_fused_moe_kernel( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int8_w8a8=use_int8_w8a8, even_Ks=even_Ks, **config, ) @@ -714,6 +732,7 @@ def inplace_fused_experts( activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -730,6 +749,7 @@ def inplace_fused_experts( activation, use_fp8_w8a8, use_int8_w8a16, + use_int8_w8a8, w1_scale, w2_scale, a1_scale, @@ -747,6 +767,7 @@ def inplace_fused_experts_fake( activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -773,6 +794,7 @@ def outplace_fused_experts( activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -789,6 +811,7 @@ def outplace_fused_experts( activation, use_fp8_w8a8, use_int8_w8a16, + use_int8_w8a8, w1_scale, w2_scale, a1_scale, @@ -833,6 +856,7 @@ def fused_experts( activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -849,6 +873,7 @@ def fused_experts( activation, use_fp8_w8a8, use_int8_w8a16, + use_int8_w8a8, w1_scale, w2_scale, a1_scale, @@ -866,6 +891,7 @@ def fused_experts( activation, use_fp8_w8a8, use_int8_w8a16, + use_int8_w8a8, w1_scale, w2_scale, a1_scale, @@ -884,6 +910,7 @@ def fused_experts_impl( activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -975,7 +1002,6 @@ def fused_experts_impl( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( curr_topk_ids, config["BLOCK_SIZE_M"], E ) - invoke_fused_moe_kernel( curr_hidden_states, w1, @@ -993,16 +1019,15 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int8_w8a8=use_int8_w8a8, block_shape=block_shape, ) - if activation == "silu": ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) elif activation == "gelu": ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) else: raise ValueError(f"Unsupported activation: {activation=}") - invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -1020,6 +1045,7 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int8_w8a8=use_int8_w8a8, block_shape=block_shape, ) @@ -1064,6 +1090,7 @@ def fused_moe( custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int8_w8a8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, @@ -1130,6 +1157,7 @@ def fused_moe( activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int8_w8a8=use_int8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, diff --git a/python/sglang/srt/layers/quantization/int8.py b/python/sglang/srt/layers/quantization/int8.py new file mode 100644 index 00000000000..52ebdb9d7bc --- /dev/null +++ b/python/sglang/srt/layers/quantization/int8.py @@ -0,0 +1,583 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py + +import logging +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + apply_fp8_linear, + convert_to_channelwise, + cutlass_fp8_supported, + per_tensor_dequantize, + requantize_with_max_scale, +) + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8_utils import ( + BlockQuantScaleParameter, + apply_w8a8_block_fp8_linear, + normalize_e4m3fn_to_e4m3fnuz, +) +from sglang.srt.utils import ( + get_bool_env_var, + is_hip, + permute_weight, + print_warning_once, + set_weight_attrs, +) + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +is_hip_ = is_hip() + +logger = logging.getLogger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + weight_block_size: List[int] = None, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + f"The block-wise quantization only supports fp8-serialized checkpoint for now." + ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." + ) + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8MoEMethod: + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by collum parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES - 修改为per-column量化 + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # 为每一列维护一个scale + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # 更新量化方法 + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else { + "quant_method": FusedMoeWeightScaleSupported.COLUMN.value + } # 改为COLUMN + ) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + padding_size, # Avoid circular import + ) + + # Block quant doesn't need to process weights after loading + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip_: + # activation_scheme: dynamic + w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w13_weight, + weight_scale=layer.w13_weight_scale_inv, + input_scale=None, + ) + w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w2_weight, + weight_scale=layer.w2_weight_scale_inv, + input_scale=None, + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + layer.w13_input_scale = None + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + layer.w2_input_scale = None + return + # If checkpoint is fp16 or bfloat16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + # If ROCm, use float8_e4m3fnuz instead (MI300x HW) + fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip_: + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.topk import select_experts + + # Expert selection + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + ) + + if is_hip_ and get_bool_env_var("CK_MOE"): + import ater + from ater.fused_moe import fused_experts_ck + + assert activation == "silu", f"{activation=} is not supported." + + return fused_experts_ck( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + else: + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 3a02531e695..d97966ebc04 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -249,7 +249,7 @@ class TestW8A8BlockFP8FusedMoE(unittest.TestCase): E = [8, 24] TOP_KS = [2, 6] BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] - # BLOCK_SIZE = [[128, 128]] + BLOCK_SIZE = [[128, 128]] SEEDS = [0] @classmethod diff --git a/python/sglang/test/test_int8.py b/python/sglang/test/test_int8.py new file mode 100644 index 00000000000..b1c7429fe46 --- /dev/null +++ b/python/sglang/test/test_int8.py @@ -0,0 +1,168 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 + + +def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): + """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K,) + A = A.reshape(M, N) + + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + return C.reshape(origin_C_shape).to(output_dtype) + + +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): + """This function performs fused moe with per-column int8 quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + # Calculate routing + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype + ) + # Activation function + act_out = SiluAndMul().forward_native(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = per_token_quant_int8(act_out) + + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) + # Apply routing weights and sum + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +class TestW8A8Int8FusedMoE(unittest.TestCase): + DTYPES = [torch.half] + M = [1, 33] + N = [128, 1024] + K = [256, 4096] + E = [8] + TOP_KS = [2, 6] + BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + # Initialize int8 quantization parameters + factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / 10 + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + # Generate scale for each column (per-column quantization) + w1_s = torch.max(torch.abs(w1_fp32), dim=2)[0] * factor_for_scale # [E, 2*N] + w2_s = torch.max(torch.abs(w2_fp32), dim=2)[0] * factor_for_scale # [E, N] + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=False, # Not using fp8 + use_int8_w8a16=False, # Not using int8-w8a16 + use_int8_w8a8=True, # Using int8-w8a8 + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) + + # Check results + print( + "diff: ", + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))), + ) + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.05 + ) + + def test_w8a8_int8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + ): + self._w8a8_int8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)