model : add hunyuan moe #14425
Conversation
Ok, getting somewhere now. The model runs, but the output is gibberish
|
Thanks for working on this! I got the same-looking output when I tried it. The only odd things I noticed were:
Tested on an AMD 7965WX 24-core, 256GB DDR5-4800 + dual RTX A6000 (96GB total VRAM) rig. 👈 a few more commands and logs fwiw

convert:
python \
convert_hf_to_gguf.py \
--outtype bf16 \
--split-max-size 50G \
--outfile /mnt/raid/models/ubergarm/Hunyuan-A13B-Instruct-GGUF/ \
/mnt/raid/models/tencent/Hunyuan-A13B-Instruct/
...

llama-server:
model=/mnt/raid/models/ubergarm/Hunyuan-A13B-Instruct-GGUF/Hunyuan-A13B-Instruct-BF16-00001-of-00004.gguf
./build/bin/llama-server \
--model "$model" \
-fa \
-ctk f16 -ctv f16 \
-c 8192 \
-ts 48,48 \
-ngl 10 \
--threads 24 \
--host 127.0.0.1 \
--port 8080
...

client:
>>> User:
Tell a funny joke in English.
>>> Assistant:
[UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧][UNK_BYTE_0xe697a7新旧][UNK_BYTE_0xe696b0新旧] |
I don't know as much about this as you guys, but could it be that the tokenizer is splitting characters like 新 ("new") into raw bytes? The UTF-8 sequence gets split apart, and the fragments get wrapped in [UNK_BYTE_...] markers, because common Chinese characters always use 3 bytes in UTF-8.
It matches the error output above. |
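To make the byte arithmetic concrete, here is a tiny standalone check (nothing model-specific, just Python's UTF-8 handling; the hex values match the ones in the gibberish above):

# 新 (U+65B0) and 旧 (U+65E7) each take 3 bytes in UTF-8. If a tokenizer emits raw
# byte fragments instead of whole characters, the detokenizer cannot decode a partial
# sequence and falls back to markers like [UNK_BYTE_0xe696b0...].
text = "新旧"
raw = text.encode("utf-8")
print(raw.hex())   # e696b0e697a7, the same bytes that appear in the [UNK_BYTE_...] output
print(len(raw))    # 6 bytes for 2 characters
print(raw[:2].decode("utf-8", errors="replace"))  # an incomplete 3-byte sequence cannot be decoded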
The cgraph is still not correct. Testing with this tiny random weight: https://huggingface.co/ngxson/hunyuan-moe-tiny-random/tree/main It seems the problem is in the self-attention block. |
I don't know if the improvements I am seeing are from your last changes. The changes I made were:
my edits are here: https://github.com/kooshi/llama.cpp/tree/hunyuan
|
The more I look at the upstream implementation, the more I wonder if it actually works. My Mac M3 Ultra can't load the original model even though it has 512GB of RAM. Now testing with the tiny weight. Switching between … Also, … And more importantly, … If that is true, it means they messed up badly this time. |
https://www.diffchecker.com/P3e0hQM5/ https://huggingface.co/tencent/Tencent-Hunyuan-Large/blob/main/Hunyuan-A52B-Instruct/ and https://www.diffchecker.com/P9FIR5OD/ In other words, it's almost Hunyuan Large? I'm not sure why the HF attention implementations would be bugged, but other reimplementations like vllm's seem to work, so maybe they can shed some light on this: |
I take that back, apparently vllm is only sometimes working with A13B, heh: |
I had the original model from Huggingface work coherently on pure CPU. It uses the HunYuanSdpaAttention codepath. This is all tentative as I just got it running at all: if I compare logits for a single-token prompt, I get a very similar logit distribution from both llama.cpp and the HF implementation. With more than one token, things look different. I'm going purely by numerical token IDs for llama.cpp, as the tokenizer is messed up as observed (I tried 'a', token 64, for the single-token prompt and '12', tokens (16, 17), for the two-token test).

This is with the combined code from @ngxson and @kooshi, with the .gguf made with @kooshi's code (I took the latest efforts I saw here in the discussion to start off). Below in the dropdown is the test.py I used.

My machine has 256GB of memory, a Hetzner server with a modern AMD EPYC CPU. I do have a Mac Studio (M2, 192GB) as well, but for CPU work this Hetzner is usually much faster. (I don't know why asking it to use bfloat16 helps; maybe it avoids making giant copies of tensors when you ask it to do that. It's just something I observed and never checked what it's doing behind the scenes.)

test.py
This is a version of the example code from the Huggingface page that I modified a bit.

#!/usr/bin/env python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import re
def main():
    with torch.no_grad():
        model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
        messages = [
            {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
        ]
        tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt",
            enable_thinking=True  # Toggle thinking mode (default: True)
        )
        outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=20)
        output_text = tokenizer.decode(outputs[0])
        print(outputs)
        print(output_text)

if __name__ == '__main__':
    main()

stdout of test.py
The output contains both the generated token IDs and the decoded text (the two print() calls). To run this, you need to install
I'm on and off this weekend trying to also figure out where exactly the computation graph is off. If I find out before someone else does, I'll let you all know. (Runs surprisingly fast on transformers+CPU; I'm used to that combo being extraordinarily slow. It is still very slow, just not "it will take 30 minutes to make 10 tokens" slow.) |
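For anyone who wants to repeat the single-token logit comparison described above, here is a minimal sketch of the HF side only; the model path and the token id (64 == 'a') are taken from the comment and are assumptions about the local setup:

import torch
from transformers import AutoModelForCausalLM

model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu",
    torch_dtype=torch.bfloat16, trust_remote_code=True)

with torch.no_grad():
    input_ids = torch.tensor([[64]])                  # single-token prompt 'a'
    logits = model(input_ids).logits[0, -1].float()   # next-token logits
    top = torch.topk(logits, 10)
    for tok_id, score in zip(top.indices.tolist(), top.values.tolist()):
        print(tok_id, score)   # compare these ids/scores against llama.cpp's output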
Is it possible to load this model in 4-bit precision using Transformers? Does bitsandbytes support this model? I’m limited to a total of 72GB of VRAM across several GPUs, so bfloat16 won’t work for me. |
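For reference, the generic transformers + bitsandbytes 4-bit loading path is sketched below; whether bitsandbytes actually supports this trust_remote_code MoE model is exactly the open question, so treat this as untested:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "tencent/Hunyuan-A13B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",        # spread layers across the available GPUs
    trust_remote_code=True,
)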
Their official inference script for running the int4 quant on vllm is using … (still didn't work for me, though) |
To add to @ubergarm's options, I did notice there are some quantized versions like https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8 or https://huggingface.co/tencent/Hunyuan-A13B-Instruct-GPTQ-Int4 (they look like they are designed to work with …). The GPTQ-Int4 one has a single … I haven't tried any of them. For computation graph work it feels better to get whatever is the highest precision I am able to run conveniently. |
If someone can run it, could you please verify if |
@ngxson is this the part you wanted to see if it's None or not? The argument to forward()? (screenshot)

Edit: took a bigger screenshot to show more clearly where I put that. Stdout tail, because that first paste is cut off; I see
Edit2: I'm going to let this thing generate a full response which might take a while. But I feel this might be a bit short as a test; it almost verbatim mentions the prompt in the <think> so maybe it's about to repeat itself or something. I'll paste as a new comment when it's done. Just want to get more confirmation the HF implementation itself works beyond very short generations. |
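A minimal sketch of one way to check this without editing the remote code: register a forward pre-hook with kwargs (PyTorch >= 2.0) on the attention modules and print whether attention_mask arrives as None. The module-name filter is an assumption about the remote code's layout, and if the mask is passed positionally it will show up in args rather than kwargs:

import torch
from transformers import AutoModelForCausalLM

model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'   # same path as test.py
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu",
    torch_dtype=torch.bfloat16, trust_remote_code=True)

def report_attention_mask(module, args, kwargs):
    # prints None-ness of the attention_mask keyword for each hooked module call
    print(type(module).__name__, "attention_mask is None:",
          kwargs.get("attention_mask", None) is None)

for name, module in model.named_modules():
    if "attn" in name.lower():   # rough filter, adjust to the actual module names
        module.register_forward_pre_hook(report_attention_mask, with_kwargs=True)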
Full response example of the stdout from test2.py (I cut off all the parts that said attention mask is None)
Code is almost the same as before, pasting for reproducibility:

test2.py
#!/usr/bin/env python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import re
def main():
    with torch.no_grad():
        model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
        messages = [
            {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
        ]
        tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt",
            enable_thinking=True  # Toggle thinking mode (default: True)
        )
        outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=5000)
        output_text = tokenizer.decode(outputs[0])
        print(outputs)
        print(output_text)

if __name__ == '__main__':
    main()

The output looks normal to me and it answered the prompt. It does look to me like it works. CPU-only, 256GB Hetzner server. |
Here is the PPL with the pretrain model from https://huggingface.co/tencent/Hunyuan-A13B-Pretrain:

make -j && ./bin/llama-perplexity -m ../models/hunyuan-a13b-pt/ggml-model-q8_0.gguf -f wikitext-2-raw/wiki.test.raw -fa
Final estimate: PPL = 5.2861 +/- 0.03234
@ngxson Do the logits match if this new expert router algorithm is disabled in the reference implementation?
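For context, the baseline that would be compared against is plain top-k softmax routing, which is roughly what the llama.cpp graph implements. A generic PyTorch sketch of that baseline (names like router_logits and n_expert_used are illustrative, not Hunyuan's actual code, and whether the selected weights get renormalized depends on the model config):

import torch

def plain_topk_route(router_logits: torch.Tensor, n_expert_used: int):
    # router_logits: [n_tokens, n_expert]
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weight, topk_idx = torch.topk(probs, n_expert_used, dim=-1)
    # optional renormalization so the selected experts' weights sum to 1 per token
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return topk_weight, topk_idx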
I'm running the polyglot aider test with the Q8 gguf from Bullerwins. It's not passing any tests. The responses are well formed with thinking ON, but with thinking OFF it just misses everything. - dirname: 2025-07-03-08-17-48--Hunyuan-A13B-Instruct-q8_0-5 |
Great work!
I've pulled the code and tested it on an x86 CPU server; fp16 and int8 inference works, but the results seem not quite as accurate as when running on vLLM.
Just giving some comments about the model version, and also the chat template.
@@ -6436,6 +6439,155 @@ def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
Could you align with Hunyuan's naming, with a version V1 suffix?
|
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
Also, could you add the version suffix to the arch name, like the arch name in the model's config.json?
@@ -656,6 +657,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.DOTS1: "dots1",
    MODEL_ARCH.ARCEE: "arcee",
    MODEL_ARCH.ERNIE4_5: "ernie4_5",
    MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
hunyuan-moe-v1 would be a better name for later model updates.
@@ -117,6 +117,7 @@ extern "C" {
        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
        LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
        LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
        LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
Adding a version suffix to the vocab type would be better.
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_DOTS1, "dots1" },
    { LLM_ARCH_ARCEE, "arcee" },
    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
    { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
Also here.
@@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
        if (add_ass) {
            ss << "<|response|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
        // tencent/Hunyuan-A13B-Instruct
Shouldn't the chat template of Hunyuan A13B be a much more complex one, with a quick/slow think option?
Also, the model enables slow think by default.
Does llama.cpp have an option like enable_thinking in the Hugging Face example?
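For what it's worth, on the Hugging Face side the toggle is just the enable_thinking kwarg forwarded to the chat template (as in test.py above); a minimal sketch, assuming the same local model path:

from transformers import AutoTokenizer

model_path = '/home/shannon/llama.cpp/tencent_Hunyuan-A13B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, trust_remote_code=True)
messages = [{"role": "user", "content": "Hello"}]
# enable_thinking is passed through to the Jinja chat template by apply_chat_template
fast = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=False)
slow = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
print(fast)
print(slow)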
@@ -1656,6 +1657,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "seed-coder") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
            clean_spaces = false;
        } else if (
                tokenizer_pre == "hunyuan") {
The tokenizer version should also get a version suffix.
@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
            # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
            res = "minerva-7b"
        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
            res = "hunyuan"
The model name would be better as hunyuan-a13b.
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
    {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
    {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
The model name should be hunyuan-a13b. From my sources, they will release more LLM models soon, so we'd better add some identifier for the model.
This is the tokenizer name, not the model name.
Perfectly working commit, can you review and approve this, pls? @ggerganov 🥺🙏 |
I think the logits still have to be verified between the GGUF and the original model implementation (disabling the custom expert router mechanism) first. There hasn't been an update yet from @ngxson as to whether it does match. |
Based on community testing, these merges are coherent. Investigating the router block of code from #14425 (comment) is just a small improvement for the future. I encourage merging based on the evidence so far. Great-looking model. |
For the record, when I skimmed the vllm PR that added the "inference only" model code, it did not appear to implement the custom expert selection either. I would also vote to merge as is, unless someone with the time and hardware can do some deeper comparisons with vllm at f16. In the meantime, it's quite usable. |
Just adding my +1 for merge; I went and tested the latest code with Q6_K from bullerwins/Hunyuan-A13B-Instruct-GGUF:
On H100 it looks really nice in terms of speed:
Also llama-bench just for completeness:
Notes:
It looks good! |
Ok, that sounds like a good explanation.
@kooshi Earlier you said that the model behaves weird. Did something change? |
The weirdness I was seeing may have been from my settings, or perhaps inherent to the model. It was quite smart, it just stumbled over its own … Edit: thinking back, I was running it with … |
@kooshi In my testing, it's extremely sensitive to sampling. The model is very prone to looping, very sensitive to prompt formatting, yet "uncertain" about its own think formatting. The think formatting is also multiple tokens (i.e. not a single token), which gives it more opportunity to 'mess up.' A relatively high MinP seems to help it behave, but the default sampling in some UIs would definitely trip it up. |
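As a concrete example of the "relatively high MinP" suggestion, here is a request against llama-server's native /completion endpoint; the field names follow the server's sampling parameters and the values are purely illustrative:

import requests

resp = requests.post("http://127.0.0.1:8080/completion", json={
    "prompt": "Tell a funny joke in English.",
    "n_predict": 256,
    "temperature": 0.7,   # illustrative
    "min_p": 0.1,         # relatively high min-p, as suggested above
    "top_k": 0,           # disable top-k so min-p does the filtering
})
print(resp.json()["content"])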
* origin/master:
  model : fix hunyuan moe chat template (ggml-org#14584)
  model : add SmolLM3 (ggml-org#14581)
  memory : fix broken batch splits for recurrent cache (ggml-org#14575)
  vulkan : fix rope with partial rotation and non-cont src (ggml-org#14582)
  server: Add ability to mount server at prefix (ggml-org#14544)
  model : add hunyuan moe (ggml-org#14425)
  vulkan: increase timeout for CI (ggml-org#14574)
  cuda : fix rope with partial rotation and non-cont src (ggml-org#14580)
  CUDA: add bilinear interpolation for upscale (ggml-org#14563)
  musa: fix build warnings (unused variable) (ggml-org#14561)
  llama : fix incorrect minicpm3 v_states shape (ggml-org#14571)
  llama : remove ggml_cont where possible (ggml-org#14568)
This model is broken for me. I converted the HF weights to GGUF this morning after the PR was merged and made a fresh Q4_K_M quantization. I'm getting lots of broken output and, as @Downtown-Case mentioned, the model doesn't seem to know how to format its own messages. It will close and open the |
Oh let me see. I'll try that now. |
It's working! I fed the model the entire
|
Re perplexity values - I'm getting PPL increasing from 1 to 3 to 5 to 7 to 27 to 37 and now 227 :( |
This does not sound good. Did you apply this MR when testing? The chat template should be fixed by this PR. |
@kzjeef Yes I recompiled from source - I'll see how high the PPL goes - I'll still try to make some quants! |
The tested PPL has been absurdly high in every test of the Instruct model, including the official implementation, despite it being coherent in chats. The base model gives a perfectly reasonable score: #14425 (comment) If anyone can verify what it's actually predicting it might help (probably trying to start with …). I hope it doesn't get in the way of the heuristics for the dynamic quants. I always look forward to them. |
About the reasoning parser: where is it located in llama.cpp? I'm working on vllm's reasoning parser (vllm-project/vllm#20625); maybe someone, or I, can port this to llama.cpp. Actually, we have tested some complex math cases internally after this PR: #14584
|
So my imatrix gets … However, I think, as someone mentioned, it's due to … Quants are at https://huggingface.co/unsloth/Hunyuan-A13B-Instruct-GGUF Usage:
|
Probably the easiest way to see what is causing this is to start generation from a single … EDIT: I think this also shows it might be time to consider letting … |
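A sketch of that "start generation from a single token" experiment via the HF reference implementation; the specific starting token is elided in the comment above, so BOS is used here purely as an illustrative choice:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "tencent/Hunyuan-A13B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu",
    torch_dtype=torch.bfloat16, trust_remote_code=True)

start_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 0
with torch.no_grad():
    out = model.generate(torch.tensor([[start_id]]), max_new_tokens=64, do_sample=False)
print(out[0].tolist())       # token ids, for comparison with llama.cpp
print(tokenizer.decode(out[0]))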
* model : add hunyuan moe
* tokenizer ok
* fix tensor name
* cgraph init
* chat template
* wip
* almost working
* skip embed, fix bos
* cleanup
* yarn scaling
* cleanup
* correct rope type
* failed token fix
* ntk alpha freq_base
* tokenization working
* cleanup and pr changes
* vocab_size sanity check
* ntk alpha generic
* Update convert_hf_to_gguf.py
* Apply suggestions from code review
* fix regression
* fix style

---------

Co-authored-by: kooshi <[email protected]>
Fix #14415
TODO: