diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py
index b6052a935a..5618ce3341 100644
--- a/python/mlc_llm/interface/compile.py
+++ b/python/mlc_llm/interface/compile.py
@@ -131,6 +131,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
         target=args.target,
         flashinfer=args.opt.flashinfer,
         faster_transformer=args.opt.faster_transformer,
+        cutlass=args.opt.cutlass,
     )
     # Step 1. Create the quantized model
     logger.info("Creating model from: %s", args.config)
diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py
index 2c44efc10d..b4ff81e6eb 100644
--- a/python/mlc_llm/interface/compiler_flags.py
+++ b/python/mlc_llm/interface/compiler_flags.py
@@ -21,6 +21,7 @@ class OptimizationFlags:
     cublas_gemm: bool = False
     faster_transformer: bool = False
     cudagraph: bool = False
+    cutlass: bool = False
 
     def __repr__(self) -> str:
         out = StringIO()
@@ -28,6 +29,7 @@ def __repr__(self) -> str:
         print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="")
         print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="")
         print(f";cudagraph={int(self.cudagraph)}", file=out, end="")
+        print(f";cutlass={int(self.cutlass)}", file=out, end="")
         return out.getvalue().rstrip()
 
     @staticmethod
@@ -49,12 +51,14 @@ def boolean(value: str) -> bool:
         parser.add_argument("--cublas_gemm", type=boolean, default=False)
         parser.add_argument("--faster_transformer", type=boolean, default=False)
         parser.add_argument("--cudagraph", type=boolean, default=False)
+        parser.add_argument("--cutlass", type=boolean, default=False)
         results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
         return OptimizationFlags(
             flashinfer=results.flashinfer,
             cublas_gemm=results.cublas_gemm,
             faster_transformer=results.faster_transformer,
             cudagraph=results.cudagraph,
+            cutlass=results.cutlass,
         )
 
     def update(self, target, quantization) -> None:
@@ -90,9 +94,16 @@ def _faster_transformer(target) -> bool:
                 return False
             return self.faster_transformer
 
+        def _cutlass(target) -> bool:
+            """correct cutlass flag"""
+            if not target.kind.name == "cuda":
+                return False
+            return self.cutlass
+
         self.flashinfer = _flashinfer(target)
         self.cublas_gemm = _cublas_gemm(target, quantization)
         self.faster_transformer = _faster_transformer(target)
+        self.cutlass = _cutlass(target)
 
 
 @dataclasses.dataclass
@@ -148,17 +159,20 @@ def from_str(source: str) -> "ModelConfigOverride":
         cublas_gemm=True,
         faster_transformer=True,
         cudagraph=False,
+        cutlass=True,
     ),
     "O2": OptimizationFlags(
         flashinfer=True,
         cublas_gemm=True,
         faster_transformer=True,
         cudagraph=False,
+        cutlass=True,
     ),
     "O3": OptimizationFlags(
         flashinfer=True,
         cublas_gemm=True,
         faster_transformer=True,
         cudagraph=True,
+        cutlass=True,
     ),
 }
diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py
index 3f41988788..ec8025f3dc 100644
--- a/python/mlc_llm/model/mixtral/mixtral_model.py
+++ b/python/mlc_llm/model/mixtral/mixtral_model.py
@@ -74,7 +74,9 @@ def _expert_forward(x: Tensor, indptr: Tensor):
         # expert_weights: [num_tokens, experts_per_tok]
         # expert_indices: [num_tokens, experts_per_tok]
         expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok)
-        use_ft = op_ext.get_store().faster_transformer and self.dtype == "float16"
+        use_ft = (
+            op_ext.get_store().cutlass_group_gemm or op_ext.get_store().faster_transformer
+        ) and self.dtype == "float16"
         if num_tokens == 1:
             # x: [num_tokens * experts_per_tok, hidden_size]
             x = _expert_forward(x, expert_indices)
diff --git a/python/mlc_llm/nn/expert.py b/python/mlc_llm/nn/expert.py
index b6659d3d60..481b430baf 100644
--- a/python/mlc_llm/nn/expert.py
+++ b/python/mlc_llm/nn/expert.py
@@ -2,7 +2,7 @@
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor
 
-from mlc_llm.op import extern, ft_gemm, moe_matmul
+from mlc_llm.op import cutlass, extern, ft_gemm, moe_matmul
 
 
 class MixtralExperts(nn.Module):
@@ -21,6 +21,8 @@ def forward(self, x: Tensor, indptr: Tensor):  # pylint: disable=invalid-name,mi
             assert indptr.shape[0] == 1
             return moe_matmul.gemv(x, self.weight, indptr)
         assert indptr.ndim == 1
+        if extern.get_store().cutlass_group_gemm and self.dtype == "float16":
+            return cutlass.group_gemm(x, self.weight, indptr)
         if extern.get_store().faster_transformer and self.dtype == "float16":
             return ft_gemm.faster_transformer_moe_gemm(x, self.weight, indptr)
         return moe_matmul.group_gemm(x, self.weight, indptr)
diff --git a/python/mlc_llm/op/cutlass.py b/python/mlc_llm/op/cutlass.py
new file mode 100644
index 0000000000..275d61f20a
--- /dev/null
+++ b/python/mlc_llm/op/cutlass.py
@@ -0,0 +1,76 @@
+"""Operators enabled by external modules."""
+
+from typing import Optional
+
+from tvm.relax.frontend import nn
+from tvm.relax.frontend.nn import op
+
+
+def group_gemm(
+    x: nn.Tensor,
+    weight: nn.Tensor,
+    indptr: nn.Tensor,
+    scale: Optional[nn.Tensor] = None,
+    weight_dtype: Optional[str] = None,
+    out_dtype: Optional[str] = None,
+):  # pylint: disable=too-many-arguments
+    """
+    Cutlass group gemm operator.
+
+    Parameters
+    ----------
+    x : nn.Tensor
+        The input tensor, with shape of [m, k].
+
+    weight : nn.Tensor
+        The weight tensor, with shape of [num_groups, n, k].
+
+    indptr : nn.Tensor
+        The indptr tensor, with shape of [num_groups].
+
+    scale : Optional[nn.Tensor]
+        The scale tensor, with shape of [1].
+
+    weight_dtype: Optional[str]
+        The data type of the weight tensor.
+
+    out_dtype: Optional[str]
+        The data type of the output tensor.
+
+    Returns
+    -------
+    nn.Tensor
+        The output tensor, with shape of [m, n].
+    """
+    assert x.ndim == 2
+    assert weight.ndim == 3
+    assert indptr.ndim == 1
+    assert weight.shape[2] == x.shape[1]
+    assert weight.shape[0] == indptr.shape[0]
+    assert indptr.dtype == "int64"
+    out_dtype = out_dtype if out_dtype else x.dtype
+    weight_dtype = weight_dtype if weight_dtype else weight.dtype
+
+    if x.dtype == "e5m2_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16":
+        func_name = "cutlass.group_gemm_e5m2_e5m2_fp16"
+    elif x.dtype == "e4m3_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16":
+        func_name = "cutlass.group_gemm_e4m3_e5m2_fp16"
+    elif x.dtype == "e4m3_float8" and weight.dtype == "e4m3_float8" and out_dtype == "float16":
+        func_name = "cutlass.group_gemm_e4m3_e4m3_fp16"
+    elif x.dtype == "float16" and weight.dtype == "float16" and out_dtype == "float16":
+        func_name = "cutlass.group_gemm_fp16_sm90"
+    else:
+        raise NotImplementedError(
+            f"Unsupported data type: x={x.dtype}, weight={weight.dtype}, out={out_dtype}"
+        )
+
+    if "float8" in x.dtype:
+        assert scale is not None, "scale is required for float8 input"
+
+    workspace = op.empty((4096 * 1024,), dtype="uint8", name="workspace")
+
+    return op.extern(
+        func_name,
+        args=[x, weight, indptr, workspace] + ([scale] if scale is not None else []),
+        out=nn.Tensor.placeholder((x.shape[0], weight.shape[1]), dtype=out_dtype),
+    )
diff --git a/python/mlc_llm/op/extern.py b/python/mlc_llm/op/extern.py
index 5fa7e829f2..fd5d91badb 100644
--- a/python/mlc_llm/op/extern.py
+++ b/python/mlc_llm/op/extern.py
@@ -28,13 +28,14 @@ class ExternModuleStore:
     target: Optional[Target] = None
     flashinfer: bool = False
     faster_transformer: bool = False
+    cutlass_group_gemm: bool = False
 
 
 STORE: ExternModuleStore = ExternModuleStore()
 """Singleton of `ExternModuleStore`."""
 
 
-def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None:
+def enable(target: Target, flashinfer: bool, faster_transformer: bool, cutlass: bool) -> None:
     """Enable external modules. It should be called before any compilation happens."""
     global STORE  # pylint: disable=global-statement
     STORE = ExternModuleStore(
@@ -42,6 +43,9 @@ def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None:
         target=target,
         flashinfer=flashinfer,
         faster_transformer=faster_transformer,
+        cutlass_group_gemm=cutlass
+        and target.kind.name == "cuda"
+        and target.attrs.get("arch", "") == "sm_90a",
     )