Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
runtime_common = [
"aiohttp", "decord", "fastapi",
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
"orjson", "outlines>=0.0.44,<0.1.0",
"packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
"xgrammar>=0.1.10"
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
"torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar==0.1.10", "ninja", "transformers==4.48.3"
]
srt = [
"sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1",
"flashinfer==0.1.6"
"sgl-kernel>=0.0.3.post6", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"flashinfer_python>=0.2.1.post2",
"outlines>=0.0.44,<=0.1.11",
]

# HIP (Heterogeneous-computing Interface for Portability) for AMD
Expand Down
36 changes: 32 additions & 4 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import orjson
import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -82,6 +83,7 @@ def fused_moe_kernel(
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
use_int8_w8a8: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Expand All @@ -104,6 +106,7 @@ def fused_moe_kernel(
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.

This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
Expand Down Expand Up @@ -165,6 +168,16 @@ def fused_moe_kernel(
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)

if use_int8_w8a8:
# Load per-column scale for weights
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
Expand Down Expand Up @@ -221,6 +234,8 @@ def fused_moe_kernel(
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
elif use_int8_w8a8:
accumulator = (accumulator * a_scale[:, None] * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
Expand Down Expand Up @@ -473,6 +488,7 @@ def invoke_fused_moe_kernel(
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int8_w8a8: bool,
block_shape: Optional[List[int]] = None,
) -> None:
assert topk_weights.stride(1) == 1
Expand All @@ -493,6 +509,8 @@ def invoke_fused_moe_kernel(
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16:
assert B_scale is not None
elif use_int8_w8a8:
A, A_scale = per_token_quant_int8(A)
else:
assert A_scale is None
assert B_scale is None
Expand All @@ -507,7 +525,6 @@ def invoke_fused_moe_kernel(
even_Ks = True
else:
even_Ks = False

fused_moe_kernel[grid](
A,
B,
Expand Down Expand Up @@ -541,6 +558,7 @@ def invoke_fused_moe_kernel(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8,
even_Ks=even_Ks,
**config,
)
Expand Down Expand Up @@ -714,6 +732,7 @@ def inplace_fused_experts(
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand All @@ -730,6 +749,7 @@ def inplace_fused_experts(
activation,
use_fp8_w8a8,
use_int8_w8a16,
use_int8_w8a8,
w1_scale,
w2_scale,
a1_scale,
Expand All @@ -747,6 +767,7 @@ def inplace_fused_experts_fake(
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand All @@ -773,6 +794,7 @@ def outplace_fused_experts(
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand All @@ -789,6 +811,7 @@ def outplace_fused_experts(
activation,
use_fp8_w8a8,
use_int8_w8a16,
use_int8_w8a8,
w1_scale,
w2_scale,
a1_scale,
Expand Down Expand Up @@ -833,6 +856,7 @@ def fused_experts(
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand All @@ -849,6 +873,7 @@ def fused_experts(
activation,
use_fp8_w8a8,
use_int8_w8a16,
use_int8_w8a8,
w1_scale,
w2_scale,
a1_scale,
Expand All @@ -866,6 +891,7 @@ def fused_experts(
activation,
use_fp8_w8a8,
use_int8_w8a16,
use_int8_w8a8,
w1_scale,
w2_scale,
a1_scale,
Expand All @@ -884,6 +910,7 @@ def fused_experts_impl(
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -975,7 +1002,6 @@ def fused_experts_impl(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], E
)

invoke_fused_moe_kernel(
curr_hidden_states,
w1,
Expand All @@ -993,16 +1019,15 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8,
block_shape=block_shape,
)

if activation == "silu":
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
elif activation == "gelu":
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported activation: {activation=}")

invoke_fused_moe_kernel(
intermediate_cache2,
w2,
Expand All @@ -1020,6 +1045,7 @@ def fused_experts_impl(
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8,
block_shape=block_shape,
)

Expand Down Expand Up @@ -1064,6 +1090,7 @@ def fused_moe(
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1130,6 +1157,7 @@ def fused_moe(
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
Expand Down
Loading