Skip to content

Commit 7ff2e9d

Browse files
fairydreaming, sszymczy
authored and committed
llama : Add support for DeepSeek V3 (ggml-org#11049)
* convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type
* vocab : add DeepSeek V3 pre-tokenizer regexes
* unicode : handle ACCENT_MARK and SYMBOL categories in regex
* llama : add DeepSeek V3 chat template, handle new model parameters and tensor types

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 50a29b3 commit 7ff2e9d

16 files changed

+162 -5 lines changed

convert_hf_to_gguf.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -687,6 +687,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
687687
if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
688688
# ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
689689
res = "megrez"
690+
if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
691+
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
692+
res = "deepseek-v3"
690693

691694
if res is None:
692695
logger.warning("\n")
@@ -3849,6 +3852,7 @@ def prepare_tensors(self):
38493852

38503853

38513854
@Model.register("DeepseekV2ForCausalLM")
3855+
@Model.register("DeepseekV3ForCausalLM")
38523856
class DeepseekV2Model(Model):
38533857
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
38543858

@@ -3870,6 +3874,15 @@ def set_gguf_parameters(self):
38703874
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
38713875
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
38723876
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
3877+
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
3878+
3879+
if hparams["scoring_func"] == "sigmoid":
3880+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
3881+
elif hparams["scoring_func"] == "softmax":
3882+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
3883+
else:
3884+
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
3885+
38733886
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
38743887

38753888
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3882,6 +3895,16 @@ def set_gguf_parameters(self):
38823895
_experts: list[dict[str, Tensor]] | None = None
38833896

38843897
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
3898+
# rename e_score_correction_bias tensors
3899+
if name.endswith("e_score_correction_bias"):
3900+
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
3901+
3902+
# skip Multi-Token Prediction (MTP) layers
3903+
block_count = self.hparams["num_hidden_layers"]
3904+
match = re.match(r"model.layers.(\d+)", name)
3905+
if match and int(match.group(1)) >= block_count:
3906+
return []
3907+
38853908
# process the experts separately
38863909
if name.find("mlp.experts") != -1:
38873910
n_experts = self.hparams["n_routed_experts"]

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ class TOKENIZER_TYPE(IntEnum):
107107
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
108108
{"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
109109
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
110+
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
110111
]
111112

112113

gguf-py/gguf/constants.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,8 @@ class LLM:
102102
EXPERT_USED_COUNT = "{arch}.expert_used_count"
103103
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
104104
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
105+
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
106+
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
105107
POOLING_TYPE = "{arch}.pooling_type"
106108
LOGIT_SCALE = "{arch}.logit_scale"
107109
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -313,6 +315,7 @@ class MODEL_TENSOR(IntEnum):
313315
FFN_GATE_SHEXP = auto()
314316
FFN_DOWN_SHEXP = auto()
315317
FFN_UP_SHEXP = auto()
318+
FFN_EXP_PROBS_B = auto()
316319
ATTN_Q_NORM = auto()
317320
ATTN_K_NORM = auto()
318321
LAYER_OUT_NORM = auto()
@@ -498,6 +501,7 @@ class MODEL_TENSOR(IntEnum):
498501
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
499502
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
500503
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
504+
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
501505
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
502506
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
503507
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
@@ -1290,6 +1294,7 @@ class MODEL_TENSOR(IntEnum):
12901294
MODEL_TENSOR.FFN_GATE_SHEXP,
12911295
MODEL_TENSOR.FFN_DOWN_SHEXP,
12921296
MODEL_TENSOR.FFN_UP_SHEXP,
1297+
MODEL_TENSOR.FFN_EXP_PROBS_B,
12931298
],
12941299
MODEL_ARCH.CHATGLM : [
12951300
MODEL_TENSOR.TOKEN_EMBD,
@@ -1590,6 +1595,11 @@ class GGMLQuantizationType(IntEnum):
15901595
TQ2_0 = 35
15911596

15921597

1598+
class ExpertGatingFuncType(IntEnum):
1599+
SOFTMAX = 1
1600+
SIGMOID = 2
1601+
1602+
15931603
# TODO: add GGMLFileType from ggml_ftype in ggml.h
15941604

15951605

gguf-py/gguf/gguf_writer.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
RopeScalingType,
2727
PoolingType,
2828
TokenType,
29+
ExpertGatingFuncType,
2930
)
3031

3132
from .quants import quant_shape_from_byte_shape
@@ -715,6 +716,12 @@ def add_expert_shared_count(self, count: int) -> None:
715716
def add_expert_weights_scale(self, value: float) -> None:
716717
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
717718

719+
def add_expert_weights_norm(self, value: bool) -> None:
720+
self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
721+
722+
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
723+
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
724+
718725
def add_swin_norm(self, value: bool) -> None:
719726
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
720727

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -276,6 +276,10 @@ class TensorNameMap:
276276
"model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
277277
),
278278

279+
MODEL_TENSOR.FFN_EXP_PROBS_B: (
280+
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
281+
),
282+
279283
# Feed-forward up
280284
MODEL_TENSOR.FFN_UP: (
281285
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@ extern "C" {
105105
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
106106
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
107107
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
108+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
108109
};
109110

110111
enum llama_rope_type {

src/llama-arch.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
9292
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
9393
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
9494
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
95+
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
96+
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
9597
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
9698
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
9799
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -984,6 +986,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
984986
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
985987
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
986988
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
989+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
987990
},
988991
},
989992
{
@@ -1366,6 +1369,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
13661369
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
13671370
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
13681371
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
1372+
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
13691373
// this tensor is loaded for T5, but never used
13701374
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
13711375
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,8 @@ enum llm_kv {
9696
LLM_KV_EXPERT_USED_COUNT,
9797
LLM_KV_EXPERT_SHARED_COUNT,
9898
LLM_KV_EXPERT_WEIGHTS_SCALE,
99+
LLM_KV_EXPERT_WEIGHTS_NORM,
100+
LLM_KV_EXPERT_GATING_FUNC,
99101
LLM_KV_POOLING_TYPE,
100102
LLM_KV_LOGIT_SCALE,
101103
LLM_KV_DECODER_START_TOKEN_ID,
@@ -231,6 +233,7 @@ enum llm_tensor {
231233
LLM_TENSOR_FFN_DOWN_SHEXP,
232234
LLM_TENSOR_FFN_GATE_SHEXP,
233235
LLM_TENSOR_FFN_UP_SHEXP,
236+
LLM_TENSOR_FFN_EXP_PROBS_B,
234237
LLM_TENSOR_ATTN_Q_NORM,
235238
LLM_TENSOR_ATTN_K_NORM,
236239
LLM_TENSOR_LAYER_OUT_NORM,

src/llama-chat.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
4545
{ "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
4646
{ "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
4747
{ "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
48+
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
4849
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
4950
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
5051
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
@@ -148,6 +149,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
148149
return LLM_CHAT_TEMPLATE_MINICPM;
149150
} else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
150151
return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
152+
} else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
153+
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
151154
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
152155
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
153156
// EXAONE-3.0-7.8B-Instruct
@@ -453,6 +456,21 @@ int32_t llm_chat_apply_template(
453456
if (add_ass) {
454457
ss << "Assistant:";
455458
}
459+
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
460+
// DeepSeek-V3
461+
for (auto message : chat) {
462+
std::string role(message->role);
463+
if (role == "system") {
464+
ss << message->content << "\n\n";
465+
} else if (role == "user") {
466+
ss << LU8("<|User|>") << message->content;
467+
} else if (role == "assistant") {
468+
ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
469+
}
470+
}
471+
if (add_ass) {
472+
ss << LU8("<|Assistant|>");
473+
}
456474
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
457475
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
458476
// EXAONE-3.0-7.8B-Instruct

src/llama-chat.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ enum llm_chat_template {
2525
LLM_CHAT_TEMPLATE_VICUNA_ORCA,
2626
LLM_CHAT_TEMPLATE_DEEPSEEK,
2727
LLM_CHAT_TEMPLATE_DEEPSEEK_2,
28+
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
2829
LLM_CHAT_TEMPLATE_COMMAND_R,
2930
LLM_CHAT_TEMPLATE_LLAMA_3,
3031
LLM_CHAT_TEMPLATE_CHATGML_3,

0 commit comments

Comments
 (0)