Model : Add support for Kimi-K2 #14654
Conversation
Thanks for the patch, I'm excited to try this model! Running convert_hf_to_gguf based on your branch, will report back with the results. |
Uploaded Q2_K and BF16 GGUF to HuggingFace (thanks @danielhanchen for the initial BF16 conversion!) https://huggingface.co/gabriellarson/Kimi-K2-Instruct-GGUF |
Nice work @gabriellarson ! I was also trying to get a PR going - master...unslothai:llama.cpp:master - but I was primarily stuck on the K2's special regex handling - will try yours out to see if the regex works! |
Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it and I only have 384 GB of RAM, and Q3 comes in at about 440 GB, so loading it is a nightmare. |
I believe @gabriellarson made a Q2_K if you scroll down to the bottom of their huggingface repo files, and there is a link to a discussion where some folks are sharing information. Going to try this PR shortly and with luck release an imatrix file (or two eventually...) 🤞 |
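For context on those sizes, a quick back-of-the-envelope check lines up with the numbers above. This is a sketch only: it treats Kimi-K2 as roughly 1T total parameters and uses rough average bits-per-weight figures, since real GGUFs mix quant types across tensors.

```python
# Rough size estimate: parameters * bits-per-weight / 8 = bytes.
# The 1T parameter count and the bpw averages are ballpark assumptions.
params = 1.0e12

for name, bits_per_weight in [("BF16", 16.0), ("Q3_K-ish", 3.5), ("Q2_K-ish", 2.6)]:
    size_gb = params * bits_per_weight / 8 / 1e9
    print(f"{name:9s} ~{size_gb:,.0f} GB")
```

BF16 comes out around 2,000 GB (the "2 TB" GGUF above) and ~3.5 bpw lands near the ~440 GB reported for Q3, so the numbers in the thread are consistent with each other.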
@CISC I think this is ready for review now |
Thank you! The Q2 I made outputs complete gibberish - not sure if my own tweaks to the source caused this issue or if it's inherent to the commit as a whole right now. I saw the recent commits, so I will start over from a clean slate. I need to test some more to see if FA or KV quant was the issue. Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek. |
I've just started converting -> testing (based on my own dequantized BF16). |
I haven't tested any edge-cases involving complex character handling in the patch, but it works for English:
|
Thanks @anikifoss for the confirmation. I too have been able to cast fp8 to bf16, then run convert_hf_to_gguf to get a bf16 GGUF and use that to quantize a "pure" q8_0, which inferenced successfully with llama-server in a few short chats. I have my methodology details and screenshots on the hf repo discussion and am updating as I go along. Running into an issue now with imatrix dropping a lot of experts due to only 99.74% partial data; might need to look at #9400 (comment) to get a better imatrix here on mainline. |
@thad0ctor How many layers did you have to offload to vram and at what context size did you get that speed? |
I tested briefly with the provided Q2_K quant and observed a lot of repetition in the output: whole paragraphs repeating verbatim 3-4 times under different but similar headings ("Bottom Line", "In Summary", etc.), at both temp=0.3 and temp=0.6. I'm building Q8 and will try again. |
@gabriellarson I can confirm the new regex seems to work well based on tokenization ID matches! @ubergarm Yes you're correct on experts being zeros - I think I also found this to be the case. I also made some 245GB, 281GB (IQ1_S) dynamic quants + Q2_K_XL, Q4_K_XL quants at https://huggingface.co/unsloth/Kimi-K2-Instruct-GGUF - it should work fine with this PR or using my fork https://github.com/unslothai/llama.cpp - guide to run them here: https://docs.unsloth.ai/basics/kimi-k2-how-to-run-locally#run-kimi-k2-tutorials |
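For anyone wanting to reproduce the "tokenization ID matches" check mentioned above, here is a minimal sketch. It assumes a llama-server built from this PR is listening on the default port and that its /tokenize endpoint and "add_special" field behave as in current mainline; adjust for your setup.

```python
# Compare token IDs from the HF reference tokenizer against a running llama-server.
# The endpoint/field names and default port are assumptions; adjust as needed.
import requests
from transformers import AutoTokenizer

hf = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)

samples = [
    "Hello, world!",
    "llama.cpp + Kimi-K2 tokenizer check",
    "混合中英文, emoji 🦙, and numbers 12345.6789",
    "    indented\ttext\nwith newlines",
]

for text in samples:
    ref = hf.encode(text, add_special_tokens=False)
    resp = requests.post(
        "http://localhost:8080/tokenize",
        json={"content": text, "add_special": False},
    )
    got = resp.json()["tokens"]
    status = "OK      " if got == ref else "MISMATCH"
    print(f"{status} {text!r}")
    if got != ref:
        print(f"  hf  : {ref}")
        print(f"  gguf: {got}")
```

Mismatches on mixed-script or emoji-heavy strings would point at the regex port in unicode.cpp rather than at the vocab itself.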
I am pretty sure it was default context (2k) and 5 layers per card, which left about 4-5 GB headroom on each card (Q4 KV cache). I suspect that llama.cpp's splitting may be mapping layers inefficiently, so MoE enhancements on ik_llama or -ot args may add some performance once this model gets fleshed out a bit more. It's unfortunate the Kimi team didn't work with the community pre-release to get the ball rolling on compatibility with common inference engines for those of us who aren't swimming in VRAM lol |
Okay, thanks for confirming! I checked your hf repo but didn't see your I'm trying @compilade's new
Oh hey, we already discussed this, but it looks like your scripts mistakenly named another quant TQ1_0, given it doesn't actually contain that quantization type and is conflating the rough BPW range with an actual ternary-only quantization type. It's great there are a lot of options in the smaller size ranges these days, but just trying to keep the naming conventions accurate! Thanks and great job getting this beast of a model going! |
Original patch by @gabriellarson: ggml-org/llama.cpp#14654
@AesSedai I haven't tried using -fa with this model yet, perhaps it doesn't work well with flashattention and cache quantization? |
@gabriellarson Oh, excellent point. I'll disable that and try again. Edit: actually, unquanted fp16 cache somehow made it worse on the unsloth fork: ![]() trying the build from this PR again now. Edit 2: unquantized cache in both this PR build + the unsloth llama.cpp fork just outputs |
@AesSedai I tried your example with UD-Q2_K_XL via: ./llama.cpp/llama-cli --model unsloth/Kimi-K2-Instruct-GGUF/UD-Q2_K_XL/Kimi-K2-Instruct-UD-Q2_K_XL-00001-of-00008.gguf \
-ngl 99 --temp 0.6 --min-p 0.01 --jinja --seed 3407 -fa \
-sys "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." Prompt:
Output:
Is the output correct? |
@danielhanchen That looks right. I'm going to download the Q2_K_XL and give that a whirl. Still trying to figure out what's not working on my end, may try testing with llama-cli instead. |
@gabriellarson Thanks again for your effort on this. I got slowed down but hope to have some test quants made using your work on ik_llama.cpp. Sorry, I am not sure how to cherry-pick the commits so your name shows up right. Also, it seems that while Kimi-K2 is close enough to DeepSeek to run, it has a different, unique-looking chat template, as pointed out to me by some folks on the BeaverAI Club discord:
Kimi-K2-Instruct/blob/main/tokenizer_config.json#L154 Not sure if it even needs to be added given the gguf has the template, pretty sure, but the recent Hunyuan-A13B did add code in that area fwiw. Just in case you do need something like that, it would go into llama-chat.cpp, pretty sure. Here is the draft chat template code I still need to test. Thanks, will keep you posted how it goes tomorrow after I finally have some test quants 🤞 |
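As a rough illustration of the format such a template needs to produce, here is a minimal Python renderer pieced together from the <|im_system|>/<|im_user|>/<|im_assistant|>/<|im_middle|>/<|im_end|> markers that appear in the prompts and template dumps elsewhere in this thread. It is not the llama-chat.cpp draft being discussed, and the role-to-marker mapping is an assumption to verify against the model's tokenizer_config.json.

```python
# Illustrative renderer for the Kimi-K2 chat format (sketch, not llama-chat.cpp code).
# The mapping below, especially for the tool role, is an assumption.
ROLE_PREFIX = {
    "system": "<|im_system|>system",
    "user": "<|im_user|>user",
    "assistant": "<|im_assistant|>assistant",
    "tool": "<|im_system|>tool",  # the official template also appears to add a "## Return of ..." header
}

def render_kimi_k2(messages, add_generation_prompt=True):
    parts = [f'{ROLE_PREFIX[m["role"]]}<|im_middle|>{m["content"]}<|im_end|>' for m in messages]
    if add_generation_prompt:
        parts.append("<|im_assistant|>assistant<|im_middle|>")
    return "".join(parts)

if __name__ == "__main__":
    print(render_kimi_k2([
        {"role": "system", "content": "example system prompt"},
        {"role": "user", "content": "example user turn 1"},
    ]))
```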
@danielhanchen It looks like @ubergarm kindly hosted a Q8_0 quant using ik_llama.cpp, and I was able to use my existing frontend + template and got a coherent, non-mixed-language response out of it, so I'm sure that my text completion template is sane now. From ubergarm's setup the response looks good (it hits the 512 token response limit I set, so the cut-off is fine).

I've also seen a few other weird / odd things related to cache quantization. For the following, I've re-generated the reply from the test prompt I gave previously, with the settings I know were working in ubergarm's setup. The following examples are all from the unsloth Q4_K_XL gguf:
[screenshots: the regenerated reply at several different KV cache quantization settings]
Basically, I think there's SOMETHING going on with either the PR or the unsloth Q4_K_XL gguf, and it's affected by the cache quantization levels. I'm downloading the unsloth Q2_K_XL to see if it's some gguf quantization fluke, idk. |
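One way to pin down whether the cache quantization level is really what changes the output is to hold everything else fixed and diff completions across server configurations. Here is a rough sketch; the /completion request fields and the server flags in the comments are assumptions based on common llama-server usage, so adjust for your build.

```python
# Rough A/B harness: start llama-server once per cache configuration
# (e.g. -fa with -ctk/-ctv q8_0 vs the default f16 cache), replay the same
# prompt with a fixed seed, and diff the saved outputs.
import sys
import requests

PROMPT = (
    "<|im_system|>system<|im_middle|>You are a helpful assistant.<|im_end|>"
    "<|im_user|>user<|im_middle|>Explain power limiting vs undervolting.<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)

def sample(label: str, url: str = "http://localhost:8080/completion") -> None:
    r = requests.post(url, json={
        "prompt": PROMPT,
        "n_predict": 256,
        "temperature": 0.6,
        "min_p": 0.01,
        "seed": 3407,
    })
    text = r.json().get("content", "")
    with open(f"out_{label}.txt", "w") as f:
        f.write(text)
    print(f"[{label}] {len(text)} chars written to out_{label}.txt")

if __name__ == "__main__":
    # Restart the server with a different cache config before each run, e.g.:
    #   python ab_cache_check.py fa_q8   # llama-server ... -fa -ctk q8_0 -ctv q8_0
    #   python ab_cache_check.py f16     # llama-server ... (default f16 cache)
    sample(sys.argv[1] if len(sys.argv) > 1 else "run")
```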
@AesSedai it might just be because of the instruct tuning. I'd imagine K2-Base would work better for text completion |
I think I'm actually noticing way fewer issues when compiling w/o CUDA and not using my two 3090s. It still rarely throws in some Chinese characters, but the replies are like 99% English now. Edit: Nevermind, that was alright for the one-shot assistant message above, but for longer multiturn it still looks funky. Guess I'll wait for ubergarm's quants and see if the same thing happens with ik_llama on my system. |
@AesSedai I tried Q4_K_XL again with:

./llama.cpp/llama-cli --model \
Kimi-K2-Instruct-GGUF/UD-Q4_K_XL/Kimi-K2-Instruct-UD-Q4_K_XL-00001-of-00013.gguf \
-ngl 99 --temp 0.6 --jinja --min-p 0.01 -no-cnv \
--prompt "<|im_system|>system<|im_middle|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.<|im_end|><|im_user|>user<|im_middle|>Blank: I\'m (Aes) having a discussion on discord with a friend (Drummer) about GPU heat and power consumption. I\'ve given a brief analogy for power limiting vs undervolting, but it\'s not great.\n\n```\n[10:31 AM] Drummer[B.AI] : Is power limit like a software level thing\n[10:31 AM] Drummer[B.AI] : I don\'t get it\n\n[10:42 AM] Aes Sedai: Power has volts and watts and some fancy math to connect the two in a few different ways. But you might be able to think of it roughly like:\n\nPower limit: decreasing the maximum wattage a card is allowed to use (think volume, like a fire hydrant versus garden hose). Limiting wattage reduces heat because it\'s pushing less electricity through.\n\nUndervolting: this is a bit more abstract, but high voltage leads to higher heat and can induce instability (so can low voltage though, there\'s a narrow band for which voltage works). To re-use the water metaphor, this is how "cold" the water is, sort of like a measure of effectiveness. Undervolting uses like room-temp water instead of cold glacier water (normal voltage), but the tradeoff is that it requires less work to produce. So overall the card generates less heat.\n\nBetween the two, they are like different axis that can be used to reduce power draw and heat, power limiting is simpler to do, but undervolting is more effective.\n\n[10:43 AM] Drummer[B.AI]: ChatGPT pls simplify\n[10:43 AM] Aes Sedai: To really get into it also you\'d have to discuss clock speeds at mV, because that\'s how you tune undervolting\n```\n\nCould you help simplify it?<|im_end|><|im_assistant|>assistant<|im_middle|>Assistant:" and I get:
I also added
|
@RodriMora Yes I saw that as well!
|
I'm getting stuck on adding the template to llm_chat_detect_template() in llama-chat.cpp |
|
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
I just got it working with this, my first attempt forgot to
|
@ubergarm I'll add the

I also include the tool role in mine, is the tool role not necessary? |
I wasn't sure myself honestly, but yes, yours does look correct to me given my understanding of the official template. Should be fine, but I don't have a quant to test it and honestly don't know how to use proper tool calling 😅

👈 chat template decoder script

Output:

```
$ python chat_template_tester.py moonshotai/Kimi-K2-Instruct
>> chat template <<
<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 2<|im_end|><|im_system|>tool<|im_middle|>## Return of \nsome kind of tool call maybe<|im_end|><|im_assistant|>assistant<|im_middle|>
```

Python script:

```python
# chat_template_tester.py
# uv pip install transformers jinja2
# (and sometimes also sentencepiece torch statsmodels, looking at you ERNIE4.5)
from transformers import AutoTokenizer
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("model", help="Name of Hugging Face LLM repo (org/model format)")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

chat = [
    {"role": "system", "content": "example system prompt"},
    {"role": "user", "content": "example user turn 1"},
    {"role": "assistant", "content": "example assistant turn 1"},
    {"role": "user", "content": "example user turn 2"},
    {"role": "assistant", "content": "example assistant turn 2"},
    {"role": "tool", "content": "some kind of tool call maybe"},
]

print(">> chat template <<")
print(tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False))
print(">> end of chat template <<")
```
|
I used the same set_vocab approach as HunYuanMoE, and attempted to accurately represent the kimi_tokenization.py regex in unicode.cpp.
I haven't converted to GGUF yet because this model is so dang huge, and I would appreciate some feedback before I try to convert.
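For feedback on the regex port specifically, one low-effort sanity check before a full conversion is to eyeball where the pre-tokenizer splits tricky strings. The sketch below uses Python's third-party regex module; the pattern shown is the classic GPT-2 split pattern used purely as a stand-in, and the real one should be pasted in from kimi_tokenization.py (the pattern being mirrored in unicode.cpp).

```python
# Eyeball pre-tokenizer splits on edge-case strings before doing a full GGUF
# conversion. PATTERN is the classic GPT-2 split regex, used only as a stand-in;
# paste the actual pattern from kimi_tokenization.py to compare behaviour.
# Needs the third-party `regex` module (pip install regex) for \p{...} classes.
import regex

PATTERN = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

samples = [
    "Hello, world!",
    "C'est l'été, n'est-ce pas?",
    "混合中英文 tokenizer edge case 🦙",
    "x = 1234567890 + 3.14159",
]

splitter = regex.compile(PATTERN)
for s in samples:
    print(f"{s!r}\n  -> {splitter.findall(s)}")
```

Comparing these splits (with the real pattern) against what the converted GGUF produces would be one way to confirm the unicode.cpp representation before committing to the huge conversion.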