[Model] Use optimized group gemm for Mixtral
vinx13 committed Mar 20, 2024
1 parent 39d0865 commit 5cfc06b
Showing 6 changed files with 102 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/mlc_llm/interface/compile.py
@@ -131,6 +131,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
target=args.target,
flashinfer=args.opt.flashinfer,
faster_transformer=args.opt.faster_transformer,
cutlass=args.opt.cutlass,
)
# Step 1. Create the quantized model
logger.info("Creating model from: %s", args.config)
14 changes: 14 additions & 0 deletions python/mlc_llm/interface/compiler_flags.py
@@ -21,13 +21,15 @@ class OptimizationFlags:
cublas_gemm: bool = False
faster_transformer: bool = False
cudagraph: bool = False
cutlass: bool = False

def __repr__(self) -> str:
out = StringIO()
print(f"flashinfer={int(self.flashinfer)}", file=out, end="")
print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="")
print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="")
print(f";cudagraph={int(self.cudagraph)}", file=out, end="")
print(f";cutlass={int(self.cutlass)}", file=out, end="")
return out.getvalue().rstrip()

@staticmethod
@@ -49,12 +51,14 @@ def boolean(value: str) -> bool:
parser.add_argument("--cublas_gemm", type=boolean, default=False)
parser.add_argument("--faster_transformer", type=boolean, default=False)
parser.add_argument("--cudagraph", type=boolean, default=False)
parser.add_argument("--cutlass", type=boolean, default=False)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return OptimizationFlags(
flashinfer=results.flashinfer,
cublas_gemm=results.cublas_gemm,
faster_transformer=results.faster_transformer,
cudagraph=results.cudagraph,
cutlass=results.cutlass,
)

def update(self, target, quantization) -> None:
@@ -90,9 +94,16 @@ def _faster_transformer(target) -> bool:
return False
return self.faster_transformer

def _cutlass(target) -> bool:
"""correct cutlass flag"""
if not target.kind.name == "cuda":
return False
return self.cutlass

self.flashinfer = _flashinfer(target)
self.cublas_gemm = _cublas_gemm(target, quantization)
self.faster_transformer = _faster_transformer(target)
self.cutlass = _cutlass(target)


@dataclasses.dataclass
@@ -148,17 +159,20 @@ def from_str(source: str) -> "ModelConfigOverride":
cublas_gemm=True,
faster_transformer=True,
cudagraph=False,
cutlass=True,
),
"O2": OptimizationFlags(
flashinfer=True,
cublas_gemm=True,
faster_transformer=True,
cudagraph=False,
cutlass=True,
),
"O3": OptimizationFlags(
flashinfer=True,
cublas_gemm=True,
faster_transformer=True,
cudagraph=True,
cutlass=True,
),
}
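As a usage note (not part of the commit): a minimal sketch of how the new flag round-trips through the flag string, assuming the static method that builds OptimizationFlags above is named from_str, as its return statement suggests.

from mlc_llm.interface.compiler_flags import OptimizationFlags

# Parse a semicolon-separated flag string; flags not mentioned keep their defaults.
flags = OptimizationFlags.from_str("flashinfer=1;cutlass=1")
assert flags.cutlass is True
assert flags.cudagraph is False

# __repr__ serializes back into the same format, now including the cutlass field.
print(flags)  # flashinfer=1;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=1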
4 changes: 3 additions & 1 deletion python/mlc_llm/model/mixtral/mixtral_model.py
@@ -74,7 +74,9 @@ def _expert_forward(x: Tensor, indptr: Tensor):
# expert_weights: [num_tokens, experts_per_tok]
# expert_indices: [num_tokens, experts_per_tok]
expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok)
- use_ft = op_ext.get_store().faster_transformer and self.dtype == "float16"
use_ft = (
op_ext.get_store().cutlass_group_gemm or op_ext.get_store().faster_transformer
) and self.dtype == "float16"
if num_tokens == 1:
# x: [num_tokens * experts_per_tok, hidden_size]
x = _expert_forward(x, expert_indices)
4 changes: 3 additions & 1 deletion python/mlc_llm/nn/expert.py
@@ -2,7 +2,7 @@
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor

- from mlc_llm.op import extern, ft_gemm, moe_matmul
from mlc_llm.op import cutlass, extern, ft_gemm, moe_matmul


class MixtralExperts(nn.Module):
@@ -21,6 +21,8 @@ def forward(self, x: Tensor, indptr: Tensor): # pylint: disable=invalid-name,mi
assert indptr.shape[0] == 1
return moe_matmul.gemv(x, self.weight, indptr)
assert indptr.ndim == 1
if extern.get_store().cutlass_group_gemm and self.dtype == "float16":
return cutlass.group_gemm(x, self.weight, indptr)
if extern.get_store().faster_transformer and self.dtype == "float16":
return ft_gemm.faster_transformer_moe_gemm(x, self.weight, indptr)
return moe_matmul.group_gemm(x, self.weight, indptr)
76 changes: 76 additions & 0 deletions python/mlc_llm/op/cutlass.py
@@ -0,0 +1,76 @@
"""Operators enabled by external modules."""

from typing import Optional

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


def group_gemm(
x: nn.Tensor,
weight: nn.Tensor,
indptr: nn.Tensor,
scale: Optional[nn.Tensor] = None,
weight_dtype: Optional[str] = None,
out_dtype: Optional[str] = None,
): # pylint: disable=too-many-arguments
"""
Cutlass group gemm operator.
Parameters
----------
x : nn.Tensor
The input tensor, with shape of [m, k].
weight : nn.Tensor
The weight tensor, with shape of [num_groups, n, k].
indptr : nn.Tensor
The indptr tensor, with shape of [num_groups].
scale : Optional[nn.Tensor]
The scale tensor, with shape of [1].
weight_dtype: Optional[str]
The data type of the weight tensor.
out_dtype: Optional[str]
The data type of the output tensor.
Returns
-------
nn.Tensor
The output tensor, with shape of [m, n].
"""
assert x.ndim == 2
assert weight.ndim == 3
assert indptr.ndim == 1
assert weight.shape[2] == x.shape[1]
assert weight.shape[0] == indptr.shape[0]
assert indptr.dtype == "int64"
out_dtype = out_dtype if out_dtype else x.dtype
weight_dtype = weight_dtype if weight_dtype else weight.dtype

if x.dtype == "e5m2_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16":
func_name = "cutlass.group_gemm_e5m2_e5m2_fp16"
elif x.dtype == "e4m3_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16":
func_name = "cutlass.group_gemm_e4m3_e5m2_fp16"
elif x.dtype == "e4m3_float8" and weight.dtype == "e4m3_float8" and out_dtype == "float16":
func_name = "cutlass.group_gemm_e4m3_e4m3_fp16"
elif x.dtype == "float16" and weight.dtype == "float16" and out_dtype == "float16":
func_name = "cutlass.group_gemm_fp16_sm90"
else:
raise NotImplementedError(
f"Unsupported data type: x={x.dtype}, weight={weight.dtype}, out={out_dtype}"
)

if "float8" in x.dtype:
assert scale is not None, "scale is required for float8 input"

workspace = op.empty((4096 * 1024,), dtype="uint8", name="workspace")

return op.extern(
func_name,
args=[x, weight, indptr, workspace] + ([scale] if scale is not None else []),
out=nn.Tensor.placeholder((x.shape[0], weight.shape[1]), dtype=out_dtype),
)
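For intuition only, a NumPy reference of the grouped matmul this operator performs, under the assumption that indptr holds cumulative end-row offsets per group (group g owns rows indptr[g-1]:indptr[g], with an implicit leading zero); the actual convention is fixed by the CUTLASS extern, so treat this as a sketch.

import numpy as np

def group_gemm_reference(x, weight, indptr):
    # x: [m, k]; weight: [num_groups, n, k]; indptr: [num_groups] cumulative
    # row offsets (assumed). Each group's rows use that group's weight matrix.
    m, _ = x.shape
    num_groups, n, _ = weight.shape
    out = np.zeros((m, n), dtype=x.dtype)
    start = 0
    for g in range(num_groups):
        end = int(indptr[g])
        out[start:end] = x[start:end] @ weight[g].T
        start = end
    return out

# Example: 4 token rows routed to 2 experts, 2 rows each.
x = np.random.rand(4, 8).astype("float16")
w = np.random.rand(2, 16, 8).astype("float16")
indptr = np.array([2, 4], dtype="int64")
y = group_gemm_reference(x, w, indptr)  # shape (4, 16)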
6 changes: 5 additions & 1 deletion python/mlc_llm/op/extern.py
@@ -28,20 +28,24 @@ class ExternModuleStore:
target: Optional[Target] = None
flashinfer: bool = False
faster_transformer: bool = False
cutlass_group_gemm: bool = False


STORE: ExternModuleStore = ExternModuleStore()
"""Singleton of `ExternModuleStore`."""


- def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None:
def enable(target: Target, flashinfer: bool, faster_transformer: bool, cutlass: bool) -> None:
"""Enable external modules. It should be called before any compilation happens."""
global STORE # pylint: disable=global-statement
STORE = ExternModuleStore(
configured=False,
target=target,
flashinfer=flashinfer,
faster_transformer=faster_transformer,
cutlass_group_gemm=cutlass
and target.kind.name == "cuda"
and target.attrs.get("arch", "") == "sm_90a",
)
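For context, a condensed sketch of how this gate is exercised from the compile path in the first diff; the target string with arch sm_90a is illustrative and assumes the installed TVM accepts that arch.

from tvm.target import Target

from mlc_llm.op import extern

# The cutlass flag only turns into cutlass_group_gemm=True for CUDA targets
# whose arch attribute is sm_90a (Hopper); otherwise the store leaves it False.
target = Target("cuda -arch=sm_90a")
extern.enable(
    target=target,
    flashinfer=False,
    faster_transformer=False,
    cutlass=True,
)
assert extern.get_store().cutlass_group_gemm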

