[Model] Support weight absorption for DeepSeek-v2 #3115

Draft · wants to merge 1 commit into base: main
160 changes: 106 additions & 54 deletions python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -1,59 +1,91 @@
"""A pass that rewrites KV cache creation functions in IRModule."""

import json
from typing import Any, Dict
from typing import Any, Dict, Literal, Tuple

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend.nn.llm import kv_cache
from tvm.relax.frontend.nn.llm.kv_cache import RopeMode


def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
def extract_creation_args(func: relax.Function) -> Tuple[str, Dict[str, Any]]:
"""Extract the KV cache creation args from the given generic creation func."""
assert isinstance(func.body, relax.SeqExpr)
assert len(func.body.blocks) == 1
assert isinstance(func.body.blocks[0], relax.DataflowBlock)
assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
args = func.body.blocks[0].bindings[0].value.args
assert isinstance(args[0], relax.ExternFunc)
assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"

assert len(args) == 15
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 14):
if i in [10, 11]:
continue
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.StringImm)
assert isinstance(args[11], (relax.Constant, relax.PrimValue))
assert isinstance(args[14], relax.DataTypeImm)

return {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"head_dim": args[6].value.value,
"rope_mode": args[7].value.value,
"rope_scale": args[8].value.value,
"rope_theta": args[9].value.value,
"rope_scaling": json.loads(args[10].value),
"rope_ext_factors": args[11],
"rotary_dim": args[12].value.value,
"enable_disaggregation": bool(args[13].value.value),
"dtype": args[14].value,
}
call_args = func.body.blocks[0].bindings[0].value.args
assert isinstance(call_args[0], relax.ExternFunc)
assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
assert isinstance(call_args[1], relax.StringImm)

args = call_args[1:]
if args[0].value == "mha":
assert len(args) == 15
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 14):
if i in [10, 11]:
continue
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.StringImm)
assert isinstance(args[11], (relax.Constant, relax.PrimValue))
assert isinstance(args[14], relax.DataTypeImm)

return "mla", {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"head_dim": args[6].value.value,
"rope_mode": args[7].value.value,
"rope_scale": args[8].value.value,
"rope_theta": args[9].value.value,
"rope_scaling": json.loads(args[10].value),
"rope_ext_factors": args[11],
"rotary_dim": args[12].value.value,
"enable_disaggregation": bool(args[13].value.value),
"dtype": args[14].value,
}
if args[0].value == "mla":
assert len(args) == 12
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 11):
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, tvm.tir.IntImm)
assert isinstance(args[11], relax.DataTypeImm)

return "mla", {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"qk_nope_head_dim": args[6].value.value,
"qk_rope_head_dim": args[7].value.value,
"v_head_dim": args[8].value.value,
"kv_lora_rank": args[9].value.value,
"enable_disaggregation": bool(args[10].value.value),
"dtype": args[11].value,
}

raise ValueError("Cannot reach here")


@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
if mod.attrs is not None:
new_mod = new_mod.with_attrs(mod.attrs)

kwargs = extract_creation_args(creation_func)
self.attach_kv_cache_metadata(kwargs)
kv_cache_kind, kwargs = extract_creation_args(creation_func)
self.attach_kv_cache_metadata(kv_cache_kind, kwargs)

bb = relax.BlockBuilder(new_mod)
self.create_tir_paged_kv_cache(bb, kwargs)
self.create_flashinfer_paged_kv_cache(bb, kwargs)
self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
return bb.finalize()

def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
def attach_kv_cache_metadata(
self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
):
"""Attach the KV cache metadata to model metadata."""
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
}

def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
if kv_cache_kind == "mha":
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
}
elif kv_cache_kind == "mla":
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": 1,
"head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
}
else:
raise ValueError("Cannot reach here.")

def create_tir_paged_kv_cache(
self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
) -> None:
"""Create the TIR-based PagedKVCache"""
max_batch_size = relax.Var(
"max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, An
support_sliding_window,
],
):
cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
if kv_cache_kind == "mha":
cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
elif kv_cache_kind == "mla":
cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
else:
raise ValueError("Cannot reach here")
bb.emit_func_output(cache._expr) # pylint: disable=protected-access

def create_flashinfer_paged_kv_cache(
self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
) -> None:
"""Create the FlashInfer-based PagedKVCache"""
# Filter the cases which FlashInfer does not support.
if ( # pylint: disable=too-many-boolean-expressions
not self.flashinfer
or kv_cache_kind != "mha"
or str(kwargs["dtype"]) != "float16"
or kwargs["head_dim"] != 128
or (
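For reference, the MLA branch of `attach_kv_cache_metadata` above records the cache as a single key-value "head" whose width is the compressed latent plus the rotary slice, and the FlashInfer path is skipped entirely for non-"mha" kinds. Below is a minimal sketch of that sizing arithmetic; the concrete dimensions are assumed DeepSeek-V2-style values used only for illustration, not numbers read from this diff.

```python
# Illustrative sketch of the MLA cache sizing done in attach_kv_cache_metadata.
# The dimensions below are assumed DeepSeek-V2-style values, not taken from this diff.
def mla_kv_cache_metadata(num_hidden_layers: int, num_attention_heads: int,
                          kv_lora_rank: int, qk_rope_head_dim: int) -> dict:
    return {
        "num_hidden_layers": num_hidden_layers,
        "num_attention_heads": num_attention_heads,
        # MLA caches one compressed latent per token, i.e. a single "KV head" ...
        "num_key_value_heads": 1,
        # ... whose width is the latent rank plus the rotary positional slice.
        "head_dim": kv_lora_rank + qk_rope_head_dim,
    }

print(mla_kv_cache_metadata(60, 128, kv_lora_rank=512, qk_rope_head_dim=64))
# head_dim = 512 + 64 = 576 values cached per token, independent of num_attention_heads
```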
2 changes: 1 addition & 1 deletion python/mlc_llm/model/baichuan/baichuan_model.py
@@ -277,7 +277,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -353,7 +353,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/cohere/cohere_model.py
@@ -322,7 +322,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/deepseek/deepseek_model.py
@@ -428,7 +428,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
41 changes: 11 additions & 30 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -13,7 +13,7 @@
from tvm.script import tir as T

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn import PagedKVCache
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
@@ -282,8 +282,6 @@ def forward_absorb(
) # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)

compressed_kv = self.kv_a_layernorm(compressed_kv)
k_nope = compressed_kv # (b, s, 1, kv_lora_rank)
value_states = compressed_kv # (b, s, 1, kv_lora_rank)

q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

@@ -303,28 +301,9 @@ def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
query_states = op.tensor_expr_op(
concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
) # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
key_states = op.tensor_expr_op(
concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
value_states = op.pad(
value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)

qkv = op.concat(
[query_states, key_states, value_states], dim=2
) # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
output, _ = op.split(
paged_kv_cache.attention_with_fused_qkv(
layer_id,
qkv,
self.num_heads,
self.softmax_scale
* math.sqrt(
self.kv_lora_rank + self.qk_rope_head_dim
), # This is to cancel out the 1/sqrt(d) in normal attention
),
indices_or_sections=[self.kv_lora_rank],
axis=-1,
output = paged_kv_cache.mla_absorbed(
layer_id, query_states, compressed_kv, k_pe, self.softmax_scale
) # (b, s, num_heads, kv_lora_rank)
output = (
op.matmul(
@@ -645,7 +624,9 @@ def __init__(self, config: DeepseekV2Config):
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.rope_theta = config.rope_theta
self.vocab_size = config.vocab_size
@@ -724,19 +705,19 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mla(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
page_size=page_size,
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
num_key_value_heads=1,
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
rope_mode=RopeMode.NONE,
rope_scale=1,
rope_theta=self.rope_theta,
num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_lora_rank=self.kv_lora_rank,
dtype=self.dtype,
)

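The rewrite of `forward_absorb` above leans on an associativity identity: scoring a query against up-projected keys equals folding the key up-projection into the query once and scoring directly against the compressed latent, which is why only `compressed_kv` and `k_pe` need to reach the cache via `mla_absorbed`. Below is a small numerical check of that identity; the names, shapes, and values are illustrative assumptions, not this PR's kernel code.

```python
import numpy as np

# Numerical check of the weight-absorption identity behind the mla_absorbed path.
# All names, shapes, and values here are illustrative, not this PR's kernels.
rng = np.random.default_rng(0)
seq_len, kv_lora_rank, qk_nope_head_dim = 5, 8, 4

c_kv = rng.standard_normal((seq_len, kv_lora_rank))           # compressed KV latents (what the MLA cache stores)
w_uk = rng.standard_normal((qk_nope_head_dim, kv_lora_rank))  # one head's key up-projection
q_nope = rng.standard_normal((qk_nope_head_dim,))             # one query head, non-rotary part

# Standard path: materialize per-head keys, then score the query against them.
k_nope = c_kv @ w_uk.T                 # (seq_len, qk_nope_head_dim)
scores_standard = k_nope @ q_nope      # (seq_len,)

# Absorbed path: fold w_uk into the query once and score in latent space,
# so the expanded keys never need to be built or cached.
q_absorbed = w_uk.T @ q_nope           # (kv_lora_rank,)
scores_absorbed = c_kv @ q_absorbed    # (seq_len,)

assert np.allclose(scores_standard, scores_absorbed)
```

The same trick applies on the value side (the value up-projection can be folded into the output projection), which is consistent with the attention output above having width `kv_lora_rank` before the final matmul.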
2 changes: 1 addition & 1 deletion python/mlc_llm/model/eagle/eagle_model.py
@@ -164,7 +164,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gemma/gemma_model.py
@@ -306,7 +306,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt2/gpt2_model.py
@@ -297,7 +297,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -264,7 +264,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -285,7 +285,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -328,7 +328,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/internlm/internlm_model.py
@@ -288,7 +288,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,