From f9da13f45e7738eb62ebd62aa4fe7777d6e43b65 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 31 Jan 2025 11:27:22 -0500
Subject: [PATCH] [Model] Initial weight absorption impl for Deepseek-v2/v3
 (#3092)

This PR introduces the initial weight absorption implementation for the
Deepseek-v2/v3 models. This implementation will be further enhanced with
new KV cache attention interfaces for MLA.
---
 .../model/deepseek_v2/deepseek_v2_loader.py   |  37 ++++++
 .../model/deepseek_v2/deepseek_v2_model.py    | 117 +++++++++++++++++-
 .../python/integration/test_model_compile.py  |   3 +
 3 files changed, 154 insertions(+), 3 deletions(-)

diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
index be6dd97836..12a9dd718f 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
@@ -117,6 +117,43 @@ def combine_expert_gate_up(*hf_params, dtype):
             ),
         )
 
+        # map MLA kv_b_proj weight
+        attn = f"model.layers.{i}.self_attn"
+        mapping.add_mapping(
+            f"{attn}.w_uk",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[0]
+                .transpose(0, 2, 1)
+                .astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+        mapping.add_mapping(
+            f"{attn}.w_uv",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[1].astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+
     for mlc_name, mlc_param in named_parameters.items():
         if mlc_name not in mapping.param_map:
             mapping.add_mapping(
diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
index 566a429003..447d5edf4d 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -220,6 +220,8 @@ def __init__(self, config: DeepseekV2Config):
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
+        self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
+        self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))
 
         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
@@ -241,6 +243,106 @@ def forward(
         paged_kv_cache: PagedKVCache,
         layer_id: int,
         query_positions: Tensor,
+    ):
+        return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)
+
+    def forward_absorb(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
+        b, s, _ = hidden_states.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(
+                self.q_a_layernorm(self.q_a_proj(hidden_states))
+            )  # (b, s, num_heads * q_head_dim)
+        q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim))  # (b, s, num_heads, q_head_dim)
+        q_nope, q_pe = op.split(
+            q, [self.qk_nope_head_dim], axis=-1
+        )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
+        q_nope = (
+            op.matmul(
+                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
+                self.w_uk.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads, self.kv_lora_rank)
+        )  # (b, s, num_heads, kv_lora_rank)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
+            b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        compressed_kv, k_pe = op.split(
+            compressed_kv, [self.config.kv_lora_rank], axis=-1
+        )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
+
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
+        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)
+
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
+
+        def concat_nope_pe(num_heads: int):
+            def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
+                return te.compute(
+                    (b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
+                    lambda _b, _s, _h, _d: te.if_then_else(
+                        _d < self.kv_lora_rank,
+                        var_nope[_b, _s, _h, _d],
+                        var_pe[_b, _s, _h, _d - self.kv_lora_rank],
+                    ),
+                )
+
+            return f_concat_nope_pe
+
+        query_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
+        )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
+        key_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        value_states = op.pad(
+            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+
+        qkv = op.concat(
+            [query_states, key_states, value_states], dim=2
+        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
+        output, _ = op.split(
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id,
+                qkv,
+                self.num_heads,
+                self.softmax_scale
+                * math.sqrt(
+                    self.kv_lora_rank + self.qk_rope_head_dim
+                ),  # This is to cancel out the 1/sqrt(d) in normal attention
+            ),
+            indices_or_sections=[self.kv_lora_rank],
+            axis=-1,
+        )  # (b, s, num_heads, kv_lora_rank)
+        output = (
+            op.matmul(
+                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
+                self.w_uv.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads * self.v_head_dim)
+        )
+
+        return self.o_proj(output)
+
+    def forward_normal(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
     ):
         b, s, _ = hidden_states.shape
 
@@ -450,6 +552,14 @@ def _set(layer, hint):
                 self.self_attn.kv_b_proj.weight,
                 tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
             )
+            _set(
+                self.self_attn.w_uk,
+                tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
+            )
+            _set(
+                self.self_attn.w_uv,
+                tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
+            )
             _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
 
             if isinstance(self.mlp, DeepseekV2MoE):
@@ -517,7 +627,6 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
-        print(f"inputs.shape = {inputs.shape}")
         query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
@@ -535,6 +644,8 @@ def __init__(self, config: DeepseekV2Config):
         self.intermediate_size = config.intermediate_size
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -621,8 +732,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
-            head_dim=256,
+            num_key_value_heads=1,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
             rope_mode=RopeMode.NONE,
             rope_scale=1,
             rope_theta=self.rope_theta,
diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py
index 3ec70b61b3..dffa582c46 100644
--- a/tests/python/integration/test_model_compile.py
+++ b/tests/python/integration/test_model_compile.py
@@ -114,6 +114,9 @@ def test_model_compile():  # pylint: disable=too-many-locals
         if not target.startswith("cuda") and quant == "q4f16_ft":
             # FasterTransformer only works with cuda
             continue
+        if "deepseek_v2" in model and "32" in quant:
+            # Skip f32 for deepseek v2 model for now.
+            continue
         log_file = os.path.join(tmp_dir, f"lib{idx}.log")
         cmd = [
             sys.executable,
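
Note on the weight absorption used in forward_absorb (a minimal standalone
NumPy sketch, not part of the patch): instead of expanding the compressed KV
latent through kv_b_proj into per-head keys and values, forward_absorb folds
the key up-projection (w_uk) into the query and the value up-projection
(w_uv) into the attention output, so attention runs directly over the
rank-kv_lora_rank latent. The shapes H, S, R, D_nope, D_v below are
illustrative stand-ins for num_heads, sequence length, kv_lora_rank,
qk_nope_head_dim and v_head_dim, and the RoPE part, masking, and the softmax
scale are omitted; the sketch only checks that both orderings give the same
attention output.

    # Illustrative sketch only; hypothetical shapes, RoPE/masking/scaling omitted.
    import numpy as np

    rng = np.random.default_rng(0)
    H, S, R, D_nope, D_v = 4, 5, 32, 16, 8  # heads, seq len, kv_lora_rank, qk_nope_head_dim, v_head_dim

    w_uk = rng.standard_normal((H, R, D_nope))    # key up-projection, laid out like w_uk above
    w_uv = rng.standard_normal((H, D_v, R))       # value up-projection, laid out like w_uv above
    q_nope = rng.standard_normal((H, S, D_nope))  # per-head query, no-RoPE part
    c_kv = rng.standard_normal((S, R))            # compressed KV latent shared by all heads

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    # Normal path: expand the latent into per-head keys/values, then attend.
    k = np.einsum("sr,hrd->hsd", c_kv, w_uk)              # (H, S, D_nope)
    v = np.einsum("sr,hvr->hsv", c_kv, w_uv)              # (H, S, D_v)
    out_normal = np.einsum(
        "hqk,hkv->hqv", softmax(np.einsum("hqd,hkd->hqk", q_nope, k)), v
    )

    # Absorbed path: fold w_uk into the query and w_uv into the output, so the
    # attention itself runs over the rank-R latent (what the patch caches).
    q_absorbed = np.einsum("hqd,hrd->hqr", q_nope, w_uk)  # (H, S, R)
    probs = softmax(np.einsum("hqr,kr->hqk", q_absorbed, c_kv))
    out_absorbed = np.einsum(
        "hqr,hvr->hqv", np.einsum("hqk,kr->hqr", probs, c_kv), w_uv
    )

    assert np.allclose(out_normal, out_absorbed)

This identity is why the KV cache above is created with num_key_value_heads=1
and head_dim=self.kv_lora_rank + self.qk_rope_head_dim: every head shares the
single compressed latent plus the small RoPE component, rather than caching
full per-head keys and values.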