Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Initial weight absorption impl for Deepseek-v2/v3 #3092

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,43 @@ def combine_expert_gate_up(*hf_params, dtype):
),
)

# map MLA kv_b_proj weight
attn = f"model.layers.{i}.self_attn"
mapping.add_mapping(
f"{attn}.w_uk",
[f"{attn}.kv_b_proj.weight"],
functools.partial(
lambda kv_b_proj, dtype: np.split(
kv_b_proj.reshape(
model_config.num_key_value_heads,
model_config.qk_nope_head_dim + model_config.v_head_dim,
model_config.kv_lora_rank,
),
indices_or_sections=[model_config.qk_nope_head_dim],
axis=1,
)[0]
.transpose(0, 2, 1)
.astype(dtype),
dtype=mlc_param.dtype,
),
)
mapping.add_mapping(
f"{attn}.w_uv",
[f"{attn}.kv_b_proj.weight"],
functools.partial(
lambda kv_b_proj, dtype: np.split(
kv_b_proj.reshape(
model_config.num_key_value_heads,
model_config.qk_nope_head_dim + model_config.v_head_dim,
model_config.kv_lora_rank,
),
indices_or_sections=[model_config.qk_nope_head_dim],
axis=1,
)[1].astype(dtype),
dtype=mlc_param.dtype,
),
)

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
Expand Down
117 changes: 114 additions & 3 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def __init__(self, config: DeepseekV2Config):
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))

self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
Expand All @@ -241,6 +243,106 @@ def forward(
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)

def forward_absorb(
self,
hidden_states: Tensor,
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
b, s, _ = hidden_states.shape

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(
self.q_a_layernorm(self.q_a_proj(hidden_states))
) # (b, s, num_heads * q_head_dim)
q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim)) # (b, s, num_heads, q_head_dim)
q_nope, q_pe = op.split(
q, [self.qk_nope_head_dim], axis=-1
) # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
q_nope = (
op.matmul(
q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
self.w_uk.permute_dims(0, 2, 1),
)
.permute_dims(1, 0, 2)
.reshape(b, s, self.num_heads, self.kv_lora_rank)
) # (b, s, num_heads, kv_lora_rank)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
compressed_kv, k_pe = op.split(
compressed_kv, [self.config.kv_lora_rank], axis=-1
) # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)

compressed_kv = self.kv_a_layernorm(compressed_kv)
k_nope = compressed_kv # (b, s, 1, kv_lora_rank)
value_states = compressed_kv # (b, s, 1, kv_lora_rank)

q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

def concat_nope_pe(num_heads: int):
def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
return te.compute(
(b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
lambda _b, _s, _h, _d: te.if_then_else(
_d < self.kv_lora_rank,
var_nope[_b, _s, _h, _d],
var_pe[_b, _s, _h, _d - self.kv_lora_rank],
),
)

return f_concat_nope_pe

query_states = op.tensor_expr_op(
concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
) # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
key_states = op.tensor_expr_op(
concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
value_states = op.pad(
value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
) # (b, s, 1, kv_lora_rank + qk_rope_head_dim)

qkv = op.concat(
[query_states, key_states, value_states], dim=2
) # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
output, _ = op.split(
paged_kv_cache.attention_with_fused_qkv(
layer_id,
qkv,
self.num_heads,
self.softmax_scale
* math.sqrt(
self.kv_lora_rank + self.qk_rope_head_dim
), # This is to cancel out the 1/sqrt(d) in normal attention
),
indices_or_sections=[self.kv_lora_rank],
axis=-1,
) # (b, s, num_heads, kv_lora_rank)
output = (
op.matmul(
output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
self.w_uv.permute_dims(0, 2, 1),
)
.permute_dims(1, 0, 2)
.reshape(b, s, self.num_heads * self.v_head_dim)
)

return self.o_proj(output)

def forward_normal(
self,
hidden_states: Tensor,
paged_kv_cache: PagedKVCache,
layer_id: int,
query_positions: Tensor,
):
b, s, _ = hidden_states.shape

Expand Down Expand Up @@ -450,6 +552,14 @@ def _set(layer, hint):
self.self_attn.kv_b_proj.weight,
tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
)
_set(
self.self_attn.w_uk,
tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
)
_set(
self.self_attn.w_uv,
tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
)
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))

if isinstance(self.mlp, DeepseekV2MoE):
Expand Down Expand Up @@ -517,7 +627,6 @@ def __init__(self, config: DeepseekV2Config):

def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = inputs
print(f"inputs.shape = {inputs.shape}")
query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
for layer_id, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
Expand All @@ -535,6 +644,8 @@ def __init__(self, config: DeepseekV2Config):
self.intermediate_size = config.intermediate_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.rope_theta = config.rope_theta
self.vocab_size = config.vocab_size
Expand Down Expand Up @@ -621,8 +732,8 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
head_dim=256,
num_key_value_heads=1,
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
rope_mode=RopeMode.NONE,
rope_scale=1,
rope_theta=self.rope_theta,
Expand Down
3 changes: 3 additions & 0 deletions tests/python/integration/test_model_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def test_model_compile(): # pylint: disable=too-many-locals
if not target.startswith("cuda") and quant == "q4f16_ft":
# FasterTransformer only works with cuda
continue
if "deepseek_v2" in model and "32" in quant:
# Skip f32 for deepseek v2 model for now.
continue
log_file = os.path.join(tmp_dir, f"lib{idx}.log")
cmd = [
sys.executable,
Expand Down
Loading