[Model] Initial weight absorption impl for Deepseek-v2/v3
This PR introduces the initial weight absorption implementation
for the Deepseek-v2/v3 models. This implementation will be further
enhanced with new KV cache attention interfaces for MLA.
MasterJH5574 committed Jan 20, 2025
1 parent b8838a1 commit 8f7f9ba
Showing 2 changed files with 151 additions and 3 deletions.
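For context on what weight absorption buys here: MLA's key up-projection can be folded into the query, and the value up-projection applied after attention, so attention runs directly over the compressed kv_lora_rank-wide latent instead of decompressed per-head keys and values. Below is a minimal NumPy sketch of that equivalence, with made-up dimensions and with RoPE, softmax, and scaling omitted; the w_uk/w_uv layouts follow the new parameters introduced in this diff.

```python
import numpy as np

# Small, made-up dimensions purely for illustration.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, seq = 2, 4, 4, 8, 3

rng = np.random.default_rng(0)
q_nope = rng.standard_normal((num_heads, seq, qk_nope_head_dim))
c_kv = rng.standard_normal((seq, kv_lora_rank))  # compressed KV latent per token
w_uk = rng.standard_normal((num_heads, kv_lora_rank, qk_nope_head_dim))  # same layout as the new w_uk
w_uv = rng.standard_normal((num_heads, v_head_dim, kv_lora_rank))        # same layout as the new w_uv

# Normal MLA path: decompress keys/values per head, then attend.
k = np.einsum("hrd,sr->hsd", w_uk, c_kv)                  # (h, s, qk_nope_head_dim)
v = np.einsum("hdr,sr->hsd", w_uv, c_kv)                  # (h, s, v_head_dim)
scores_normal = np.einsum("hqd,hkd->hqk", q_nope, k)
out_normal = np.einsum("hqk,hkd->hqd", scores_normal, v)

# Absorbed path: fold w_uk into the query, attend over the latent, apply w_uv afterwards.
q_latent = np.einsum("hqd,hrd->hqr", q_nope, w_uk)        # (h, s, kv_lora_rank)
scores_absorbed = np.einsum("hqr,kr->hqk", q_latent, c_kv)
out_latent = np.einsum("hqk,kr->hqr", scores_absorbed, c_kv)
out_absorbed = np.einsum("hqr,hdr->hqd", out_latent, w_uv)

# Softmax is omitted: it would act on identical score matrices in both paths.
assert np.allclose(scores_normal, scores_absorbed)
assert np.allclose(out_normal, out_absorbed)
```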
37 changes: 37 additions & 0 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
@@ -117,6 +117,43 @@ def combine_expert_gate_up(*hf_params, dtype):
),
)

# map MLA kv_b_proj weight
attn = f"model.layers.{i}.self_attn"
mapping.add_mapping(
f"{attn}.w_uk",
[f"{attn}.kv_b_proj.weight"],
functools.partial(
lambda kv_b_proj, dtype: np.split(
kv_b_proj.reshape(
model_config.num_key_value_heads,
model_config.qk_nope_head_dim + model_config.v_head_dim,
model_config.kv_lora_rank,
),
indices_or_sections=[model_config.qk_nope_head_dim],
axis=1,
)[0]
.transpose(0, 2, 1)
.astype(dtype),
dtype=mlc_param.dtype,
),
)
mapping.add_mapping(
f"{attn}.w_uv",
[f"{attn}.kv_b_proj.weight"],
functools.partial(
lambda kv_b_proj, dtype: np.split(
kv_b_proj.reshape(
model_config.num_key_value_heads,
model_config.qk_nope_head_dim + model_config.v_head_dim,
model_config.kv_lora_rank,
),
indices_or_sections=[model_config.qk_nope_head_dim],
axis=1,
)[1].astype(dtype),
dtype=mlc_param.dtype,
),
)

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
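The mapping above derives both absorbed parameters from the single HuggingFace kv_b_proj.weight, whose rows stack each head's key and value up-projections. A standalone NumPy sketch of that split, using made-up dimensions in place of the real model config:

```python
import numpy as np

# Made-up config values for illustration only.
num_key_value_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 8, 16

kv_b_proj = np.random.randn(
    num_key_value_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank
)

# Reshape to (heads, qk_nope + v, lora_rank), then split off the key part.
reshaped = kv_b_proj.reshape(
    num_key_value_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
)
k_part, v_part = np.split(reshaped, indices_or_sections=[qk_nope_head_dim], axis=1)

w_uk = k_part.transpose(0, 2, 1)  # (heads, kv_lora_rank, qk_nope_head_dim)
w_uv = v_part                     # (heads, v_head_dim, kv_lora_rank)

assert w_uk.shape == (num_key_value_heads, kv_lora_rank, qk_nope_head_dim)
assert w_uv.shape == (num_key_value_heads, v_head_dim, kv_lora_rank)
```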
117 changes: 114 additions & 3 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -214,6 +214,8 @@ def __init__(self, config: DeepseekV2Config):
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))

self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
@@ -235,6 +237,106 @@ def forward(
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)

def forward_absorb(
self,
hidden_states: Tensor,
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
b, s, _ = hidden_states.shape

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(
self.q_a_layernorm(self.q_a_proj(hidden_states))
) # (b, s, num_heads * q_head_dim)
q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim)) # (b, s, num_heads, q_head_dim)
q_nope, q_pe = op.split(
q, [self.qk_nope_head_dim], axis=-1
) # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
q_nope = (
op.matmul(
q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
self.w_uk.permute_dims(0, 2, 1),
)
.permute_dims(1, 0, 2)
.reshape(b, s, self.num_heads, self.kv_lora_rank)
) # (b, s, num_heads, kv_lora_rank)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
compressed_kv, k_pe = op.split(
compressed_kv, [self.config.kv_lora_rank], axis=-1
) # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)

compressed_kv = self.kv_a_layernorm(compressed_kv)
k_nope = compressed_kv # (b, s, 1, kv_lora_rank)
value_states = compressed_kv # (b, s, 1, kv_lora_rank)

q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

def concat_nope_pe(num_heads: int):
def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
return te.compute(
(b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
lambda _b, _s, _h, _d: te.if_then_else(
_d < self.kv_lora_rank,
var_nope[_b, _s, _h, _d],
var_pe[_b, _s, _h, _d - self.kv_lora_rank],
),
)

return f_concat_nope_pe

query_states = op.tensor_expr_op(
concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
) # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
key_states = op.tensor_expr_op(
concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
value_states = op.pad(
value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)

qkv = op.concat(
[query_states, key_states, value_states], dim=2
) # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
output, _ = op.split(
paged_kv_cache.attention_with_fused_qkv(
layer_id,
qkv,
self.num_heads,
self.softmax_scale
* math.sqrt(
self.kv_lora_rank + self.qk_rope_head_dim
                ),  # cancel out the 1/sqrt(head_dim) that the fused attention kernel applies internally
),
indices_or_sections=[self.kv_lora_rank],
axis=-1,
) # (b, s, num_heads, kv_lora_rank)
output = (
op.matmul(
output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
self.w_uv.permute_dims(0, 2, 1),
)
.permute_dims(1, 0, 2)
.reshape(b, s, self.num_heads * self.v_head_dim)
)

return self.o_proj(output)
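A note on the scaling factor passed to attention_with_fused_qkv in forward_absorb: each cached vector is treated as a single head of width kv_lora_rank + qk_rope_head_dim, so (assuming the fused kernel applies the usual 1/sqrt(head_dim) internally) the scale is pre-multiplied by sqrt(kv_lora_rank + qk_rope_head_dim) to cancel that factor and leave only softmax_scale. A small sketch of the arithmetic with illustrative DeepSeek-V2-style numbers, not taken from this diff:

```python
import math

# Illustrative DeepSeek-V2-style dimensions (assumed, not from this diff).
kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim = 512, 64, 128
q_head_dim = qk_nope_head_dim + qk_rope_head_dim

softmax_scale = q_head_dim ** -0.5                 # intended MLA scale (ignoring any rope-scaling mscale)
cache_head_dim = kv_lora_rank + qk_rope_head_dim   # width of one latent "head" in the KV cache

# Scaling factor handed to the fused kernel, as in forward_absorb above.
scaling_factor = softmax_scale * math.sqrt(cache_head_dim)

# Assuming the kernel computes score = (q . k) * scaling_factor / sqrt(cache_head_dim),
# the effective scale reduces back to softmax_scale.
effective_scale = scaling_factor / math.sqrt(cache_head_dim)
assert math.isclose(effective_scale, softmax_scale)
```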

def forward_normal(
self,
hidden_states: Tensor,
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
b, s, _ = hidden_states.shape

@@ -444,6 +546,14 @@ def _set(layer, hint):
self.self_attn.kv_b_proj.weight,
tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
)
_set(
self.self_attn.w_uk,
tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
)
_set(
self.self_attn.w_uv,
tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
)
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))

if isinstance(self.mlp, DeepseekV2MoE):
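The new sharding hints split w_uk and w_uv along dim 0, i.e. per head, which lines up with the existing dim-0 sharding of kv_b_proj.weight, whose rows are grouped head by head. A quick NumPy sanity check of that consistency for w_uk, with made-up dimensions:

```python
import numpy as np

# Made-up dimensions: 4 KV heads split across 2 tensor-parallel shards.
heads, d_nope, d_v, rank, tp = 4, 8, 8, 16, 2
kv_b = np.random.randn(heads * (d_nope + d_v), rank)

def derive_w_uk(kv_b_shard, heads_in_shard):
    reshaped = kv_b_shard.reshape(heads_in_shard, d_nope + d_v, rank)
    return np.split(reshaped, [d_nope], axis=1)[0].transpose(0, 2, 1)

# Sharding kv_b_proj.weight along dim 0 keeps whole heads together, so deriving
# w_uk per shard ...
per_shard = [derive_w_uk(s, heads // tp) for s in np.split(kv_b, tp, axis=0)]
# ... matches deriving w_uk once and sharding it along its own head dim (dim 0).
whole = np.split(derive_w_uk(kv_b, heads), tp, axis=0)
assert all(np.allclose(a, b) for a, b in zip(per_shard, whole))
```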
@@ -511,7 +621,6 @@ def __init__(self, config: DeepseekV2Config):

def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = inputs
print(f"inputs.shape = {inputs.shape}")
query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
for layer_id, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
@@ -529,6 +638,8 @@ def __init__(self, config: DeepseekV2Config):
self.intermediate_size = config.intermediate_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.rope_theta = config.rope_theta
self.vocab_size = config.vocab_size
@@ -615,8 +726,8 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
-            head_dim=256,
+            num_key_value_heads=1,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
rope_mode=RopeMode.NONE,
rope_scale=1,
rope_theta=self.rope_theta,
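The KV cache change above is where the absorbed layout pays off: instead of num_key_value_heads entries of head_dim 256 per token, the cache now holds a single latent head of width kv_lora_rank + qk_rope_head_dim. A rough back-of-the-envelope comparison, assuming DeepSeek-V2-style config values and that the paged cache keeps one key slot and one value slot of head_dim per head (both assumptions, not taken from this diff):

```python
# Illustrative DeepSeek-V2-style values (assumed, not from this diff).
num_key_value_heads = 128
kv_lora_rank, qk_rope_head_dim = 512, 64

# Per-token, per-layer cache width, assuming one K slot and one V slot per head.
old_width = 2 * num_key_value_heads * 256                # previous head_dim=256 layout
new_width = 2 * 1 * (kv_lora_rank + qk_rope_head_dim)    # single latent head of width 576

print(old_width, new_width, round(old_width / new_width, 1))  # 65536 1152 56.9
```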
