[Model] Support weight absorption for DeepSeek-v2
This PR adds support for weight absorption in the DeepSeek-v2/v3 models.

At the moment, absorption is enabled by default, even though computing the
first prefill without weight absorption is usually more efficient. The
logic that switches between weight absorption and normal computation is
still in progress.
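
For context: "weight absorption" here refers to the usual MLA (multi-head latent attention) trick of folding the key up-projection W_UK into the query path (and, symmetrically, W_UV into the output projection), so attention can read the compressed kv_lora_rank-wide latent cache directly instead of decompressing full keys and values at every step. Below is a minimal NumPy sketch of the score-side equivalence, with illustrative names and shapes rather than the actual mlc_llm kernels:

# Illustrative sketch of MLA weight absorption (not mlc_llm code).
import numpy as np

rng = np.random.default_rng(0)
seq_len, num_heads = 4, 2
qk_nope_head_dim, kv_lora_rank = 8, 16

# Per-head no-RoPE query part and the compressed KV latent that gets cached.
q_nope = rng.standard_normal((num_heads, seq_len, qk_nope_head_dim))
compressed_kv = rng.standard_normal((seq_len, kv_lora_rank))
# Per-head key up-projection W_UK: latent space -> key space.
w_uk = rng.standard_normal((num_heads, qk_nope_head_dim, kv_lora_rank))

# Without absorption: decompress keys for every head, then take q . k.
k_nope = np.einsum("sr,hdr->hsd", compressed_kv, w_uk)
scores_plain = np.einsum("hqd,hkd->hqk", q_nope, k_nope)

# With absorption: fold W_UK into the query instead, so scores are computed
# directly against the compressed latent cache (width kv_lora_rank).
q_absorbed = np.einsum("hqd,hdr->hqr", q_nope, w_uk)
scores_absorbed = np.einsum("hqr,kr->hqk", q_absorbed, compressed_kv)

assert np.allclose(scores_plain, scores_absorbed)

Because the absorbed form computes scores over the larger kv_lora_rank dimension, it trades extra FLOPs for not materializing per-head keys and values, which is roughly why the compute-bound first prefill is usually faster without absorption while memory-bound decode benefits from it.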
MasterJH5574 committed Feb 2, 2025
1 parent f9da13f commit 070472b
Showing 31 changed files with 200 additions and 114 deletions.
160 changes: 106 additions & 54 deletions python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -1,59 +1,91 @@
"""A pass that rewrites KV cache creation functions in IRModule."""

import json
from typing import Any, Dict
from typing import Any, Dict, Literal, Tuple

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend.nn.llm import kv_cache
from tvm.relax.frontend.nn.llm.kv_cache import RopeMode


def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
def extract_creation_args(func: relax.Function) -> Tuple[str, Dict[str, Any]]:
"""Extract the KV cache creation args from the given generic creation func."""
assert isinstance(func.body, relax.SeqExpr)
assert len(func.body.blocks) == 1
assert isinstance(func.body.blocks[0], relax.DataflowBlock)
assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
args = func.body.blocks[0].bindings[0].value.args
assert isinstance(args[0], relax.ExternFunc)
assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"

assert len(args) == 15
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 14):
if i in [10, 11]:
continue
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.StringImm)
assert isinstance(args[11], (relax.Constant, relax.PrimValue))
assert isinstance(args[14], relax.DataTypeImm)

return {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"head_dim": args[6].value.value,
"rope_mode": args[7].value.value,
"rope_scale": args[8].value.value,
"rope_theta": args[9].value.value,
"rope_scaling": json.loads(args[10].value),
"rope_ext_factors": args[11],
"rotary_dim": args[12].value.value,
"enable_disaggregation": bool(args[13].value.value),
"dtype": args[14].value,
}
call_args = func.body.blocks[0].bindings[0].value.args
assert isinstance(call_args[0], relax.ExternFunc)
assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
assert isinstance(call_args[1], relax.StringImm)

args = call_args[1:]
if args[0].value == "mha":
assert len(args) == 15
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 14):
if i in [10, 11]:
continue
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.StringImm)
assert isinstance(args[11], (relax.Constant, relax.PrimValue))
assert isinstance(args[14], relax.DataTypeImm)

return "mla", {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"head_dim": args[6].value.value,
"rope_mode": args[7].value.value,
"rope_scale": args[8].value.value,
"rope_theta": args[9].value.value,
"rope_scaling": json.loads(args[10].value),
"rope_ext_factors": args[11],
"rotary_dim": args[12].value.value,
"enable_disaggregation": bool(args[13].value.value),
"dtype": args[14].value,
}
if call_args[1].value == "mla":
assert len(args) == 12
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
assert isinstance(args[2], relax.ShapeExpr)
for i in range(3, 11):
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, tvm.tir.IntImm)
assert isinstance(args[11], relax.DataTypeImm)

return "mla", {
"max_batch_size": args[1].values[0],
"max_total_seq_len": args[1].values[1],
"prefill_chunk_size": args[1].values[2],
"page_size": args[1].values[3],
"support_sliding_window": args[1].values[4],
"layer_partition": args[2],
"num_hidden_layers": args[3].value.value,
"num_attention_heads": args[4].value.value,
"num_key_value_heads": args[5].value.value,
"qk_nope_head_dim": args[6].value.value,
"qk_rope_head_dim": args[7].value.value,
"v_head_dim": args[8].value.value,
"kv_lora_rank": args[9].value.value,
"enable_disaggregation": bool(args[10].value.value),
"dtype": args[11].value,
}

raise ValueError("Cannot reach here")


@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
if mod.attrs is not None:
new_mod = new_mod.with_attrs(mod.attrs)

kwargs = extract_creation_args(creation_func)
self.attach_kv_cache_metadata(kwargs)
kv_cache_kind, kwargs = extract_creation_args(creation_func)
self.attach_kv_cache_metadata(kv_cache_kind, kwargs)

bb = relax.BlockBuilder(new_mod)
self.create_tir_paged_kv_cache(bb, kwargs)
self.create_flashinfer_paged_kv_cache(bb, kwargs)
self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
return bb.finalize()

def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
def attach_kv_cache_metadata(
self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
):
"""Attach the KV cache metadata to model metadata."""
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
}

def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
if kv_cache_kind == "mha":
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
}
elif kv_cache_kind == "mla":
self.metadata["kv_cache"] = {
"num_hidden_layers": kwargs["num_hidden_layers"],
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": 1,
"head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
}
else:
raise ValueError("Cannot reach here.")

def create_tir_paged_kv_cache(
self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
) -> None:
"""Create the TIR-based PagedKVCache"""
max_batch_size = relax.Var(
"max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
support_sliding_window,
],
):
cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
if kv_cache_kind == "mha":
cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
elif kv_cache_kind == "mla":
cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
else:
raise ValueError("Cannot reach here")
bb.emit_func_output(cache._expr) # pylint: disable=protected-access

def create_flashinfer_paged_kv_cache(
self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
) -> None:
"""Create the FlashInfer-based PagedKVCache"""
# Filter the cases which FlashInfer does not support.
if ( # pylint: disable=too-many-boolean-expressions
not self.flashinfer
or kv_cache_kind != "mha"
or str(kwargs["dtype"]) != "float16"
or kwargs["head_dim"] != 128
or (
2 changes: 1 addition & 1 deletion python/mlc_llm/model/baichuan/baichuan_model.py
@@ -277,7 +277,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -353,7 +353,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/cohere/cohere_model.py
@@ -322,7 +322,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/deepseek/deepseek_model.py
@@ -428,7 +428,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
41 changes: 11 additions & 30 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -13,7 +13,7 @@
from tvm.script import tir as T

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn import PagedKVCache
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
@@ -282,8 +282,6 @@ def forward_absorb(
) # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)

compressed_kv = self.kv_a_layernorm(compressed_kv)
k_nope = compressed_kv # (b, s, 1, kv_lora_rank)
value_states = compressed_kv # (b, s, 1, kv_lora_rank)

q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

@@ -303,28 +301,9 @@ def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
query_states = op.tensor_expr_op(
concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
) # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
key_states = op.tensor_expr_op(
concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
value_states = op.pad(
value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)

qkv = op.concat(
[query_states, key_states, value_states], dim=2
) # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
output, _ = op.split(
paged_kv_cache.attention_with_fused_qkv(
layer_id,
qkv,
self.num_heads,
self.softmax_scale
* math.sqrt(
self.kv_lora_rank + self.qk_rope_head_dim
), # This is to cancel out the 1/sqrt(d) in normal attention
),
indices_or_sections=[self.kv_lora_rank],
axis=-1,
output = paged_kv_cache.mla_absorbed(
layer_id, query_states, compressed_kv, k_pe, self.softmax_scale
) # (b, s, num_heads, kv_lora_rank)
output = (
op.matmul(
@@ -645,7 +624,9 @@ def __init__(self, config: DeepseekV2Config):
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.rope_theta = config.rope_theta
self.vocab_size = config.vocab_size
@@ -724,19 +705,19 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mla(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
page_size=page_size,
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
num_key_value_heads=1,
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
rope_mode=RopeMode.NONE,
rope_scale=1,
rope_theta=self.rope_theta,
num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_lora_rank=self.kv_lora_rank,
dtype=self.dtype,
)

2 changes: 1 addition & 1 deletion python/mlc_llm/model/eagle/eagle_model.py
@@ -164,7 +164,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gemma/gemma_model.py
@@ -306,7 +306,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt2/gpt2_model.py
@@ -297,7 +297,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -264,7 +264,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -285,7 +285,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -328,7 +328,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
2 changes: 1 addition & 1 deletion python/mlc_llm/model/internlm/internlm_model.py
@@ -288,7 +288,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size: tir.Var,
support_sliding_window: tir.Var,
) -> PagedKVCache:
return PagedKVCache.create_generic(
return PagedKVCache.create_generic_mha(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,