diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
index 53eeb9df25..b729935ccd 100644
--- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
+++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -1,7 +1,7 @@
 """A pass that rewrites KV cache creation functions in IRModule."""

 import json
-from typing import Any, Dict
+from typing import Any, Dict, Literal, Tuple

 import tvm
 from tvm import IRModule, relax
@@ -9,7 +9,7 @@
 from tvm.relax.frontend.nn.llm.kv_cache import RopeMode


-def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
+def extract_creation_args(func: relax.Function) -> Tuple[Literal["mha", "mla"], Dict[str, Any]]:
     """Extract the KV cache creation args from the given generic creation func."""
     assert isinstance(func.body, relax.SeqExpr)
     assert len(func.body.blocks) == 1
@@ -17,43 +17,75 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
     assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
     assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
     assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
-    args = func.body.blocks[0].bindings[0].value.args
-    assert isinstance(args[0], relax.ExternFunc)
-    assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
-
-    assert len(args) == 15
-    assert isinstance(args[1], relax.ShapeExpr)
-    assert len(args[1].values) == 5
-    assert isinstance(args[2], relax.ShapeExpr)
-    for i in range(3, 14):
-        if i in [10, 11]:
-            continue
-        assert isinstance(args[i], relax.PrimValue)
-        assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
-    assert isinstance(args[10], relax.StringImm)
-    assert isinstance(args[11], (relax.Constant, relax.PrimValue))
-    assert isinstance(args[14], relax.DataTypeImm)
-
-    return {
-        "max_batch_size": args[1].values[0],
-        "max_total_seq_len": args[1].values[1],
-        "prefill_chunk_size": args[1].values[2],
-        "page_size": args[1].values[3],
-        "support_sliding_window": args[1].values[4],
-        "layer_partition": args[2],
-        "num_hidden_layers": args[3].value.value,
-        "num_attention_heads": args[4].value.value,
-        "num_key_value_heads": args[5].value.value,
-        "head_dim": args[6].value.value,
-        "rope_mode": args[7].value.value,
-        "rope_scale": args[8].value.value,
-        "rope_theta": args[9].value.value,
-        "rope_scaling": json.loads(args[10].value),
-        "rope_ext_factors": args[11],
-        "rotary_dim": args[12].value.value,
-        "enable_disaggregation": bool(args[13].value.value),
-        "dtype": args[14].value,
-    }
+    call_args = func.body.blocks[0].bindings[0].value.args
+    assert isinstance(call_args[0], relax.ExternFunc)
+    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
+    assert isinstance(call_args[1], relax.StringImm)
+
+    args = call_args[1:]
+    if args[0].value == "mha":
+        assert len(args) == 15
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 14):
+            if i in [10, 11]:
+                continue
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
+        assert isinstance(args[10], relax.StringImm)
+        assert isinstance(args[11], (relax.Constant, relax.PrimValue))
+        assert isinstance(args[14], relax.DataTypeImm)
+
+        return "mha", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "head_dim": args[6].value.value,
+            "rope_mode": args[7].value.value,
+            "rope_scale": args[8].value.value,
+            "rope_theta": args[9].value.value,
+            "rope_scaling": json.loads(args[10].value),
+            "rope_ext_factors": args[11],
+            "rotary_dim": args[12].value.value,
+            "enable_disaggregation": bool(args[13].value.value),
+            "dtype": args[14].value,
+        }
+    if call_args[1].value == "mla":
+        assert len(args) == 12
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 11):
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, tvm.tir.IntImm)
+        assert isinstance(args[11], relax.DataTypeImm)
+
+        return "mla", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "qk_nope_head_dim": args[6].value.value,
+            "qk_rope_head_dim": args[7].value.value,
+            "v_head_dim": args[8].value.value,
+            "kv_lora_rank": args[9].value.value,
+            "enable_disaggregation": bool(args[10].value.value),
+            "dtype": args[11].value,
+        }
+
+    raise ValueError("Cannot reach here")


 @tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
         if mod.attrs is not None:
             new_mod = new_mod.with_attrs(mod.attrs)

-        kwargs = extract_creation_args(creation_func)
-        self.attach_kv_cache_metadata(kwargs)
+        kv_cache_kind, kwargs = extract_creation_args(creation_func)
+        self.attach_kv_cache_metadata(kv_cache_kind, kwargs)

         bb = relax.BlockBuilder(new_mod)
-        self.create_tir_paged_kv_cache(bb, kwargs)
-        self.create_flashinfer_paged_kv_cache(bb, kwargs)
+        self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
+        self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
         return bb.finalize()

-    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
+    def attach_kv_cache_metadata(
+        self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ):
         """Attach the KV cache metadata to model metadata."""
-        self.metadata["kv_cache"] = {
-            "num_hidden_layers": kwargs["num_hidden_layers"],
-            "num_attention_heads": kwargs["num_attention_heads"],
-            "num_key_value_heads": kwargs["num_key_value_heads"],
-            "head_dim": kwargs["head_dim"],
-        }
-
-    def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
+        if kv_cache_kind == "mha":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": kwargs["num_key_value_heads"],
+                "head_dim": kwargs["head_dim"],
+            }
+        elif kv_cache_kind == "mla":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": 1,
+                "head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
+            }
+        else:
+            raise ValueError("Cannot reach here.")
+
+    def create_tir_paged_kv_cache(
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ) -> None:
         """Create the TIR-based PagedKVCache"""
         max_batch_size = relax.Var(
             "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
         )
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, An
                 support_sliding_window,
             ],
         ):
-            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            if kv_cache_kind == "mha":
+                cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            elif kv_cache_kind == "mla":
+                cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
+            else:
+                raise ValueError("Cannot reach here")
             bb.emit_func_output(cache._expr)  # pylint: disable=protected-access

     def create_flashinfer_paged_kv_cache(
-        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
     ) -> None:
         """Create the FlashInfer-based PagedKVCache"""
         # Filter the cases which FlashInfer does not support.
         if (  # pylint: disable=too-many-boolean-expressions
             not self.flashinfer
+            or kv_cache_kind != "mha"
             or str(kwargs["dtype"]) != "float16"
             or kwargs["head_dim"] != 128
             or (
diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py
index f3779684b9..fb4c9c1128 100644
--- a/python/mlc_llm/model/baichuan/baichuan_model.py
+++ b/python/mlc_llm/model/baichuan/baichuan_model.py
@@ -277,7 +277,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py
index 8dd987882b..414c74ddc2 100644
--- a/python/mlc_llm/model/chatglm3/chatglm3_model.py
+++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -353,7 +353,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py
index a9c4fe7edd..f60da549c6 100644
--- a/python/mlc_llm/model/cohere/cohere_model.py
+++ b/python/mlc_llm/model/cohere/cohere_model.py
@@ -322,7 +322,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/deepseek/deepseek_model.py b/python/mlc_llm/model/deepseek/deepseek_model.py
index cdab6e1935..6c4a0a7ccd 100644
--- a/python/mlc_llm/model/deepseek/deepseek_model.py
+++ b/python/mlc_llm/model/deepseek/deepseek_model.py
@@ -428,7 +428,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
index 447d5edf4d..0eda9ae80c 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -13,7 +13,7 @@
 from tvm.script import tir as T

 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
@@ -282,8 +282,6 @@ def forward_absorb(
         )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
         compressed_kv = self.kv_a_layernorm(compressed_kv)
-        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
-        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)

         q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
@@ -303,28 +301,9 @@ def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
         query_states = op.tensor_expr_op(
             concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
         )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
-        key_states = op.tensor_expr_op(
-            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
-        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
-        value_states = op.pad(
-            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
-        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
-        qkv = op.concat(
-            [query_states, key_states, value_states], dim=2
-        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
-        output, _ = op.split(
-            paged_kv_cache.attention_with_fused_qkv(
-                layer_id,
-                qkv,
-                self.num_heads,
-                self.softmax_scale
-                * math.sqrt(
-                    self.kv_lora_rank + self.qk_rope_head_dim
-                ),  # This is to cancel out the 1/sqrt(d) in normal attention
-            ),
-            indices_or_sections=[self.kv_lora_rank],
-            axis=-1,
+        output = paged_kv_cache.mla_absorbed(
+            layer_id, query_states, compressed_kv, k_pe, self.softmax_scale
         )  # (b, s, num_heads, kv_lora_rank)
         output = (
             op.matmul(
@@ -645,7 +624,9 @@ def __init__(self, config: DeepseekV2Config):
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
         self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -724,7 +705,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mla(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
@@ -732,11 +713,11 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=1,
-            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
-            rope_mode=RopeMode.NONE,
-            rope_scale=1,
-            rope_theta=self.rope_theta,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
             dtype=self.dtype,
         )
diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py
index 9d7820b841..60dbc30601 100644
--- a/python/mlc_llm/model/eagle/eagle_model.py
+++ b/python/mlc_llm/model/eagle/eagle_model.py
@@ -164,7 +164,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py
index d74e84ab4a..f946fde635 100644
--- a/python/mlc_llm/model/gemma/gemma_model.py
+++ b/python/mlc_llm/model/gemma/gemma_model.py
@@ -306,7 +306,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py
index 02b9a2c10d..fc42d3349e 100644
--- a/python/mlc_llm/model/gpt2/gpt2_model.py
+++ b/python/mlc_llm/model/gpt2/gpt2_model.py
@@ -297,7 +297,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
index c6cd09018c..daf2add94d 100644
--- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
+++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -264,7 +264,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/gpt_j/gpt_j_model.py b/python/mlc_llm/model/gpt_j/gpt_j_model.py
index 08d365283d..527828beb0 100644
--- a/python/mlc_llm/model/gpt_j/gpt_j_model.py
+++ b/python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -285,7 +285,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
index 1f456f1aae..88555a5309 100644
--- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
+++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -328,7 +328,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py
index 4a4cc8d544..9b8d05c27b 100644
--- a/python/mlc_llm/model/internlm/internlm_model.py
+++ b/python/mlc_llm/model/internlm/internlm_model.py
@@ -288,7 +288,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/internlm2/internlm2_model.py b/python/mlc_llm/model/internlm2/internlm2_model.py
index c075d8db23..8ef09ba499 100644
--- a/python/mlc_llm/model/internlm2/internlm2_model.py
+++ b/python/mlc_llm/model/internlm2/internlm2_model.py
@@ -295,7 +295,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py
index 181baae8a4..12537e0e9c 100644
--- a/python/mlc_llm/model/llama/llama_model.py
+++ b/python/mlc_llm/model/llama/llama_model.py
@@ -396,7 +396,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py
index 430243404e..e414ca180a 100644
--- a/python/mlc_llm/model/llava/llava_model.py
+++ b/python/mlc_llm/model/llava/llava_model.py
@@ -229,7 +229,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/minicpm/minicpm_model.py b/python/mlc_llm/model/minicpm/minicpm_model.py
index d37d2ed9e5..7023c27e7e 100644
--- a/python/mlc_llm/model/minicpm/minicpm_model.py
+++ b/python/mlc_llm/model/minicpm/minicpm_model.py
@@ -428,7 +428,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py
index 854ce5be56..6dc1a9a2bd 100644
--- a/python/mlc_llm/model/mistral/mistral_model.py
+++ b/python/mlc_llm/model/mistral/mistral_model.py
@@ -302,7 +302,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/nemotron/nemotron_model.py b/python/mlc_llm/model/nemotron/nemotron_model.py
index 7d22a2f4b8..78c64f3adb 100644
--- a/python/mlc_llm/model/nemotron/nemotron_model.py
+++ b/python/mlc_llm/model/nemotron/nemotron_model.py
@@ -378,7 +378,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py
index e4ecb998db..3be06e3782 100644
--- a/python/mlc_llm/model/olmo/olmo_model.py
+++ b/python/mlc_llm/model/olmo/olmo_model.py
@@ -450,7 +450,7 @@ def create_paged_kv_cache(  # pylint: disable=missing-function-docstring,too-man
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py
index c57fb8d1e0..2e13752f4a 100644
--- a/python/mlc_llm/model/orion/orion_model.py
+++ b/python/mlc_llm/model/orion/orion_model.py
@@ -284,7 +284,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py
index 7becb7bdb4..71cfc1201f 100644
--- a/python/mlc_llm/model/phi/phi_model.py
+++ b/python/mlc_llm/model/phi/phi_model.py
@@ -406,7 +406,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/phi3/phi3_model.py b/python/mlc_llm/model/phi3/phi3_model.py
index fc3d762a00..2619933e06 100644
--- a/python/mlc_llm/model/phi3/phi3_model.py
+++ b/python/mlc_llm/model/phi3/phi3_model.py
@@ -302,7 +302,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/phi3v/phi3v_model.py b/python/mlc_llm/model/phi3v/phi3v_model.py
index 5097f45149..1790a6c196 100644
--- a/python/mlc_llm/model/phi3v/phi3v_model.py
+++ b/python/mlc_llm/model/phi3v/phi3v_model.py
@@ -282,7 +282,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py
index 00e972e0ea..0474ee44a6 100644
--- a/python/mlc_llm/model/qwen/qwen_model.py
+++ b/python/mlc_llm/model/qwen/qwen_model.py
@@ -283,7 +283,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py
index 93d2936125..dcb613198e 100644
--- a/python/mlc_llm/model/qwen2/qwen2_model.py
+++ b/python/mlc_llm/model/qwen2/qwen2_model.py
@@ -324,7 +324,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py
index 59b7ae8375..aa55927f4a 100644
--- a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py
+++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py
@@ -307,7 +307,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py
index 224c1c3915..1a61141dbf 100644
--- a/python/mlc_llm/model/stable_lm/stablelm_model.py
+++ b/python/mlc_llm/model/stable_lm/stablelm_model.py
@@ -292,7 +292,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/model/starcoder2/starcoder2_model.py b/python/mlc_llm/model/starcoder2/starcoder2_model.py
index 6477401006..6796bc6890 100644
--- a/python/mlc_llm/model/starcoder2/starcoder2_model.py
+++ b/python/mlc_llm/model/starcoder2/starcoder2_model.py
@@ -310,7 +310,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py
index d5a609e470..c1de92a73c 100644
--- a/python/mlc_llm/nn/kv_cache.py
+++ b/python/mlc_llm/nn/kv_cache.py
@@ -15,7 +15,7 @@ class PagedKVCache(TVMPagedKVCache):  # pylint: disable=too-few-public-methods
     """The Paged KV Cache used in LLM batching for efficient attention computation."""

     @staticmethod
-    def create_generic(  # pylint: disable=too-many-locals
+    def create_generic_mha(  # pylint: disable=too-many-locals
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -36,7 +36,7 @@ def create_generic(  # pylint: disable=too-many-locals
         enable_disaggregation: bool = False,
         name: str = "paged_kv_cache",
     ) -> "PagedKVCache":
-        """The generic function of creating a PagedKVCache,
+        """The generic function of creating a multi-head attention PagedKVCache,
         which will be rewritten by functions in compilation pipeline.
         """
         if rotary_dim is None:
@@ -48,6 +48,7 @@ def create_generic(  # pylint: disable=too-many-locals
         return PagedKVCache(
             _expr=rx.call_pure_packed(
                 "mlc.create_paged_kv_cache_generic",
+                rx.StringImm("mha"),
                 rx.ShapeExpr(
                     [
                         max_batch_size,
@@ -80,3 +81,55 @@ def create_generic(  # pylint: disable=too-many-locals
             ),
             _name=name,
         )
+
+    @staticmethod
+    def create_generic_mla(  # pylint: disable=too-many-locals
+        max_batch_size: tir.Var,
+        max_total_seq_len: tir.Var,
+        prefill_chunk_size: tir.Var,
+        page_size: tir.Var,
+        support_sliding_window: tir.Var,
+        num_hidden_layers: int,
+        num_attention_heads: int,
+        num_key_value_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        kv_lora_rank: int,
+        dtype: str,
+        layer_partition: Optional[List[int]] = None,
+        enable_disaggregation: bool = False,
+        name: str = "paged_kv_cache",
+    ) -> "PagedKVCache":
+        """The generic function of creating a multi-head latent attn PagedKVCache,
+        which will be rewritten by functions in compilation pipeline.
+        """
+        if layer_partition is None:
+            layer_partition = [0, num_hidden_layers]
+        return PagedKVCache(
+            _expr=rx.call_pure_packed(
+                "mlc.create_paged_kv_cache_generic",
+                rx.StringImm("mla"),
+                rx.ShapeExpr(
+                    [
+                        max_batch_size,
+                        max_total_seq_len,
+                        prefill_chunk_size,
+                        page_size,
+                        support_sliding_window,
+                    ]
+                ),
+                rx.ShapeExpr(layer_partition),
+                rx.PrimValue(num_hidden_layers),
+                rx.PrimValue(num_attention_heads),
+                rx.PrimValue(num_key_value_heads),
+                rx.PrimValue(qk_nope_head_dim),
+                rx.PrimValue(qk_rope_head_dim),
+                rx.PrimValue(v_head_dim),
+                rx.PrimValue(kv_lora_rank),
+                rx.PrimValue(int(enable_disaggregation)),
+                rx.DataTypeImm(dtype),
+                sinfo_args=rx.ObjectStructInfo(),
+            ),
+            _name=name,
+        )
diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py
index aaa8e39104..02851445e9 100644
--- a/tests/python/model/test_kv_cache.py
+++ b/tests/python/model/test_kv_cache.py
@@ -74,7 +74,7 @@ def create_paged_kv_cache(
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,