Skip to content

Model : Add support for Kimi-K2 #14654

Merged: 21 commits merged into ggml-org:master on Jul 15, 2025

Conversation

gabriellarson
Contributor

I used the same set_vocab approach as HunYuanMoE, and attempted to accurately represent the kimi_tokenization.py regex in unicode.cpp.

I haven't converted to GGUF yet because this model is so dang huge, and I would appreciate some feedback before I try to convert.

@github-actions github-actions bot added the python (python script changes) label Jul 12, 2025
@anikifoss

Thanks for the patch, I'm excited to try this model! Running convert_hf_to_gguf based on your branch, will report back with the results.

@gabriellarson
Contributor Author

gabriellarson commented Jul 13, 2025

Uploaded Q2_K and BF16 GGUF to HuggingFace (thanks @danielhanchen for the initial BF16 conversion!) https://huggingface.co/gabriellarson/Kimi-K2-Instruct-GGUF

@danielhanchen
Contributor

Nice work @gabriellarson ! I was also trying to get a PR going - master...unslothai:llama.cpp:master - but I was primarily stuck on the K2's special regex handling - will try yours out to see if the regex works!

danielhanchen added a commit to unslothai/llama.cpp that referenced this pull request Jul 13, 2025
@thad0ctor

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it, I only have 384 GB of RAM, and Q3 comes in at about 440 GB, so loading it is a nightmare.

@ubergarm

ubergarm commented Jul 14, 2025

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it, I only have 384 GB of RAM, and Q3 comes in at about 440 GB, so loading it is a nightmare.

I believe @gabriellarson made a Q2_K if you scroll down to the bottom of their Hugging Face repo files, and there is a link to a discussion where some folks are sharing information. Going to try this PR shortly and with luck release an imatrix file (or two eventually...) 🤞

@gabriellarson
Contributor Author

@CISC I think this is ready for review now

@thad0ctor

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it, I only have 384 GB of RAM, and Q3 comes in at about 440 GB, so loading it is a nightmare.

I believe @gabriellarson made a Q2_K if you scroll down to the bottom of their Hugging Face repo files, and there is a link to a discussion where some folks are sharing information. Going to try this PR shortly and with luck release an imatrix file (or two eventually...) 🤞

Thank you! The Q2 I made outputs complete gibberish - not sure if my own tweaks to the src caused this issue or if it's inherent to the commit as a whole right now. I saw the recent commits, so I will start over from a clean slate. I need to test some more to see if FA or KV quant was the issue.

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@csabakecskemeti
Contributor

I've just started converting -> testing (based on my own dequantized BF16).
Will report back (will take some time)

@anikifoss

anikifoss commented Jul 14, 2025

I haven't tested any edge-cases involving complex character handling in the patch, but it works for English:

  • convert_hf_to_gguf worked
  • quantized model loads with llama.cpp compiled from this branch and the output looks good (able to one-shot the spinning hexagon with 20 balls)

@ubergarm

Thanks @anikifoss for the confirmation. I too have been able to cast fp8 to bf16, then run convert_hf_to_gguf to get a bf16 GGUF and use that to quantize a "pure" q8_0, which ran inference successfully with llama-server in a few short chats.

I have my methodology details and screenshots in the hf repo discussion and am updating as I go along.

Running into an issue now with imatrix dropping a lot of experts due to only 99.74% partial data; might need to look at #9400 (comment) to get a better imatrix here on mainline.

@RodriMora

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@thad0ctor How many layers did you have to offload to vram and at what context size did you get that speed?

@usrlocalben

I tested briefly with the provided Q2_K quant and observed a lot of repetition in the output: whole paragraphs repeating verbatim 3-4 times under different but similar headings ("Bottom Line", "In Summary", etc.), at temp=0.3 and temp=0.6. I'm building Q8 and will try again.

@danielhanchen
Contributor

@gabriellarson I can confirm the new regex seems to work well based on tokenization ID matches!

@ubergarm Yes you're correct on experts being zeros - I think I also found this to be the case.

I also made some 245GB, 281GB (IQ1_S) dynamic quants + Q2_K_XL, Q4_K_XL quants at https://huggingface.co/unsloth/Kimi-K2-Instruct-GGUF - it should work fine with this PR or using my fork https://github.com/unslothai/llama.cpp - guide to run them here: https://docs.unsloth.ai/basics/kimi-k2-how-to-run-locally#run-kimi-k2-tutorials
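For readers who want to reproduce the tokenization ID-match check mentioned above, here is a rough sketch. It assumes a llama-server built from this branch is running locally on port 8080 and that its /tokenize endpoint accepts a JSON body with content and add_special fields; the test strings are arbitrary and the comparison is only a sanity check, not part of this PR.

# Rough sketch: compare Hugging Face tokenizer IDs against llama-server's
# /tokenize output for a few test strings. Endpoint/body shape and the test
# strings are assumptions, not something defined by this PR.
import json
import urllib.request

from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)

tests = ["Hello, world!", "  leading spaces\tand\ttabs", "混合 CJK text 123", "emoji 🤞 test"]

for text in tests:
    hf_ids = hf_tok.encode(text, add_special_tokens=False)
    req = urllib.request.Request(
        "http://127.0.0.1:8080/tokenize",
        data=json.dumps({"content": text, "add_special": False}).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        srv_ids = json.load(resp)["tokens"]
    # Matching ID lists suggest the pretokenizer regex and vocab were ported correctly.
    print("OK  " if hf_ids == srv_ids else "DIFF", repr(text))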

@thad0ctor

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@thad0ctor How many layers did you have to offload to vram and at what context size did you get that speed?

I am pretty sure it was default context (2k) and 5 layers per card, which left about 4-5 GB of headroom on each card (Q4 KV cache). I suspect the llama.cpp splitting may be mapping layers inefficiently, so MoE enhancements in ik_llama or -ot args may add some performance once this model gets fleshed out a bit more.

It's unfortunate the Kimi team didn't work with the community pre-release to get the ball rolling on compatibility with common inference engines, for those of us who aren't swimming in VRAM lol

@ubergarm

@danielhanchen

Yes you're correct on experts being zeros - I think I also found this to be the case.

Okay, thanks for confirming! I checked your hf repo but didn't see your Kimi-K2-Instruct-GGUF/imatrix_unsloth.dat there; perhaps you don't release them anymore? Also, given it is a .dat, I assume it is missing data for a lot of exps.

I'm trying @compilade's new imatrix.gguf, which seems to go ahead and save even with partial data.

I also made some 245GB, 281GB (IQ1_S) dynamic quants + Q2_K_XL, Q4_K_XL quants

Oh hey, we already discussed this, but it looks like your scripts mistakenly named another quant TQ1_0 even though it doesn't actually contain that quantization type, conflating the rough BPW range with an actual ternary-only quantization type.

It's great there are a lot of options in the smaller size ranges these days, but I'm just trying to keep the naming conventions accurate! Thanks, and great job getting this beast of a model going!

anikifoss pushed a commit to anikifoss/ik_llama.cpp that referenced this pull request Jul 14, 2025
@gabriellarson
Contributor Author

@AesSedai I haven't tried using -fa with this model yet; perhaps it doesn't work well with flash attention and cache quantization?

@AesSedai

AesSedai commented Jul 14, 2025

@gabriellarson Oh, excellent point. I'll disable that and try again.

Edit: actually, unquantized fp16 cache somehow made it worse on the unsloth fork.

Trying the build from this PR again now.

Edit 2: unquantized cache in both this PR build + the unsloth llama.cpp fork just outputs ucers on repeat 🤔

@danielhanchen
Contributor

danielhanchen commented Jul 15, 2025

@AesSedai I tried your example with UD-Q2_K_XL via:

./llama.cpp/llama-cli --model unsloth/Kimi-K2-Instruct-GGUF/UD-Q2_K_XL/Kimi-K2-Instruct-UD-Q2_K_XL-00001-of-00008.gguf \
    -ngl 99 --temp 0.6 --min-p 0.01 --jinja --seed 3407 -fa \
    -sys "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."

Prompt:

systemA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
> Blank: I\'m (Aes) having a discussion on discord with a friend (Drummer) about GPU heat and power consumption. I\'ve given a brief analogy for power limiting vs undervolting, but it\'s not great.\n\n```\n[10:31 AM] Drummer[B.AI] : Is power limit like a software level thing\n[10:31 AM] Drummer[B.AI] : I don\'t get it\n\n[10:42 AM] Aes Sedai: Power has volts and watts and some fancy math to connect the two in a few different ways. But you might be able to think of it roughly like:\n\nPower limit: decreasing the maximum wattage a card is allowed to use (think volume, like a fire hydrant versus garden hose). Limiting wattage reduces heat because it\'s pushing less electricity through.\n\nUndervolting: this is a bit more abstract, but high voltage leads to higher heat and can induce instability (so can low voltage though, there\'s a narrow band for which voltage works). To re-use the water metaphor, this is how "cold" the water is, sort of like a measure of effectiveness. Undervolting uses like room-temp water instead of cold glacier water (normal voltage), but the tradeoff is that it requires less work to produce. So overall the card generates less heat.\n\nBetween the two, they are like different axis that can be used to reduce power draw and heat, power limiting is simpler to do, but undervolting is more effective.\\n[10:43 AM] Drummer[B.AI]: ChatGPT pls simplify\n[10:43 AM] Aes Sedai: To really get into it also you\'d have to discuss clock speeds at mV, because that\'s how you tune undervolting\n```\n\nCould you help simplify it?

Output:

Sure, here’s a super-simple version you can copy-paste:

---

Think of the GPU like a car.

- **Power-limit (PL)** = putting a speed-governor on the engine that says “you can’t go faster than 70 mph.”  
  It just caps the top speed (watts), so the car uses less fuel and stays cooler. Easy to do: one slider, job done.

- **Undervolting (UV)** = making the engine run on slightly less gasoline per mile while still driving the same speed.  
  You’re tuning the fuel/air mix (voltage) so the engine does the same work with less energy. It’s trickier—you might stall if you go too lean—but when dialed in it saves more fuel and heat than just limiting top speed.

You can use either, or both. PL is the quick “set max speed” knob; UV is the fine-tune “make the engine more efficient” knob.

---

That’s it—no hoses, no glaciers.

Is the output correct?

@AesSedai

@danielhanchen That looks right. I'm going to download the Q2_K_XL and give that a whirl. Still trying to figure out what's not working on my end, may try testing with llama-cli instead.

@ubergarm

ubergarm commented Jul 15, 2025

@gabriellarson Thanks again for your effort on this. I got slowed down but hope to have some test quants made using your work on ik_llama.cpp. Sorry I am not sure how to cherry-pick the commits so your name shows up right.

Also, it seems that while Kimi-K2 is close enough to DeepSeek to run, it has a different, unique-looking chat template, as pointed out to me by some folks on the BeaverAI Club discord:

<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>

Kimi-K2-Instruct/blob/main/tokenizer_config.json#L154

Not sure if it even needs to be added given the GGUF already has the template (pretty sure), but the recent Hunyuan-A13B did add code in that area fwiw. Just in case you do need something like that, it would go into llama-chat.cpp. Here is the draft chat template code I still need to test.

Thanks, will keep you posted how it goes tomorrow after I finally have some test quants 🤞
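For reference, a minimal Python sketch of the turn layout shown above; the render_kimi_k2 helper is hypothetical (not part of llama.cpp or transformers) and just concatenates the role wrappers from the template string.

# Minimal sketch of the Kimi-K2 turn layout shown above.
# render_kimi_k2 is a hypothetical helper, not part of llama.cpp or transformers.
WRAPPERS = {
    "system":    "<|im_system|>system<|im_middle|>",
    "user":      "<|im_user|>user<|im_middle|>",
    "assistant": "<|im_assistant|>assistant<|im_middle|>",
}

def render_kimi_k2(messages, add_generation_prompt=True):
    parts = [WRAPPERS[m["role"]] + m["content"] + "<|im_end|>" for m in messages]
    if add_generation_prompt:
        parts.append("<|im_assistant|>assistant<|im_middle|>")  # open the assistant turn
    return "".join(parts)

print(render_kimi_k2([
    {"role": "system", "content": "example system prompt"},
    {"role": "user", "content": "example user turn 1"},
]))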

@AesSedai

AesSedai commented Jul 15, 2025

@danielhanchen it looks like that llama-cli call is a chat completion endpoint instead of text completion, which is what I've been using.

@ubergarm kindly hosted a Q8_0 quant using ik_llama.cpp and I was able to use my existing frontend + template and got a coherent, non-mixed language response out of it so I'm sure that my text completion template is sane now.

From ubergarm's setup the response looks good (it hits the 512 token response limit I set, so the cut-off is fine)

I've also seen a few other weird things related to cache quantization. For the following, I've re-generated the reply from the test prompt I gave previously, with the settings I know were working in ubergarm's setup. The following examples are all from the unsloth Q4_K_XL GGUF:

  • FA off, FP16 K cache, FP16 V cache: ucersucersucersucersucersucersucersucers ad infinitum (no EOS)
  • FA on, FP16 K cache, FP16 V cache: ucersucersucersucersucersucersucersucers ad infinitum (no EOS)
  • FA on, Q8 K cache, FP16 V cache: some Chinese characters, but not too wild
  • FA on, Q8 K cache, Q8 V cache: more Chinese characters
  • FA on, Q4 K cache, FP16 V cache: about the same amount, but more dispersed
  • FA on, Q4 K cache, Q8 V cache: about the same as above
  • FA on, Q4 K cache, Q4 V cache: hit the word "transition" and exploded

Basically, I think there's SOMETHING going on with either the PR, or the unsloth Q4_K_XL gguf and it's affected by the cache quantization levels.

I'm downloading the unsloth Q2_K_XL to see if it's some gguf quantization fluke idk

@gabriellarson
Contributor Author

@AesSedai it might just be because of the instruct tuning. I'd imagine K2-Base would work better for text completion

@AesSedai

AesSedai commented Jul 15, 2025

I think I'm actually noticing way fewer issues when compiling without CUDA and not using my two 3090s. It still throws in some Chinese characters rarely, but the replies are like 99% English now.

Edit: Never mind, that was all right for the one-shot assistant message above, but for longer multiturn it still looks funky. Guess I'll wait for ubergarm's quants and see if the same thing happens with ik_llama on my system.

@RodriMora

RodriMora commented Jul 15, 2025

@danielhanchen
Contributor

@AesSedai I tried Q4_K_XL again with -no-cnv --prompt via:

./llama.cpp/llama-cli --model \
    Kimi-K2-Instruct-GGUF/UD-Q4_K_XL/Kimi-K2-Instruct-UD-Q4_K_XL-00001-of-00013.gguf \
    -ngl 99 --temp 0.6 --jinja --min-p 0.01 -no-cnv \
    --prompt "<|im_system|>system<|im_middle|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.<|im_end|><|im_user|>user<|im_middle|>Blank: I\'m (Aes) having a discussion on discord with a friend (Drummer) about GPU heat and power consumption. I\'ve given a brief analogy for power limiting vs undervolting, but it\'s not great.\n\n```\n[10:31 AM] Drummer[B.AI] : Is power limit like a software level thing\n[10:31 AM] Drummer[B.AI] : I don\'t get it\n\n[10:42 AM] Aes Sedai: Power has volts and watts and some fancy math to connect the two in a few different ways. But you might be able to think of it roughly like:\n\nPower limit: decreasing the maximum wattage a card is allowed to use (think volume, like a fire hydrant versus garden hose). Limiting wattage reduces heat because it\'s pushing less electricity through.\n\nUndervolting: this is a bit more abstract, but high voltage leads to higher heat and can induce instability (so can low voltage though, there\'s a narrow band for which voltage works). To re-use the water metaphor, this is how "cold" the water is, sort of like a measure of effectiveness. Undervolting uses like room-temp water instead of cold glacier water (normal voltage), but the tradeoff is that it requires less work to produce. So overall the card generates less heat.\n\nBetween the two, they are like different axis that can be used to reduce power draw and heat, power limiting is simpler to do, but undervolting is more effective.\n\n[10:43 AM] Drummer[B.AI]: ChatGPT pls simplify\n[10:43 AM] Aes Sedai: To really get into it also you\'d have to discuss clock speeds at mV, because that\'s how you tune undervolting\n```\n\nCould you help simplify it?<|im_end|><|im_assistant|>assistant<|im_middle|>Assistant:"

and I get:

Assistant: Let’s turn the GPU into a little car on a race-track.

1. Power-limiting = putting a speed-limit sign on the track.  
   No matter how hard the driver (GPU) wants to go, the car can’t exceed 100 km/h. This keeps engine heat and fuel use low, but the car may still be running a little rich (higher voltage) to guarantee it can always hit the new top speed safely.

2. Undervolting = teaching the driver to take the same corner at the same speed while pressing the accelerator less.  
   The car still reaches 100 km/h, but it burns less fuel and the engine runs cooler because the throttle isn’t open as far. If you push the pedal too little, the car stumbles (driver crashes = GPU crash), so you find the lowest pressure that still keeps the ride smooth.

One knob caps the top speed; the other teaches the engine to run more efficiently at whatever speed you do allow. [end of text]

I also added -fa --cache-type-k q4_1 --cache-type-v q4_1 and I get:

Assistant: Absolutely—let’s turn the messy, technical explanation into a 30-second “elevator version” that still feels right to a hardware-minded reader.

---

**Old analogy (paraphrased)**  
“Power-limiting is like a speed-governor on a car, and undervolting is like switching to a higher-octane fuel that lets you keep the same speed while burning less.”

---

**Problem with that version**  
- The governor part works, but “higher-octane fuel” is backwards: higher-octane resists knock, it doesn’t inherently save fuel.  
- Undervolting is really about *reducing* the “pressure” in the engine, not changing the fuel.  
- It doesn’t show *why* you can keep the same speed (frequency) with lower voltage.

---

**Clean, 30-second replacement**

Think of a GPU like a water pump pushing water through a hose:

1. **Power Limit (Watt-cap)** is the maximum flow-rate the pump is *allowed* to move—turn the dial down and the pump simply can’t push as much water, so the stream slows.  
2. **Undervolting** is switching to a *bigger hose* at the same pump speed. With less resistance, the same flow (performance) needs less effort (voltage), so the motor runs cooler and quieter without lowering the stream.

That’s it: same flow (fps), less work (watts), lower temperature.

---

If you want a one-liner for Discord:

“Power-limit caps how hard the pump *can* push; undervolting makes the hose wider so the pump hits the same speed with less pressure—less heat, same flow.” [end of text]

@danielhanchen
Contributor

@RodriMora Yes I saw that as well!

  1. tokenizer.encode will now encode special tokens - this does NOT affect this PR; see Feature Request: Support Kimi K2 #14642 (comment), which tokenizes special tokens correctly (a quick check of the HF-side behavior is sketched after this list).
  2. The multi-turn tool-calling chat template does need an update. One has to enable the new one via --chat-template-file PATH_TO_KIMI_K2_CHAT_TEMPLATE.jinja - I will have to re-update the quants to bake in the new chat template over the next few days.
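A quick sketch of the HF-side check referenced in point 1, assuming the moonshotai/Kimi-K2-Instruct tokenizer loads with trust_remote_code and exposes the standard encode signature; the exact IDs printed will depend on the tokenizer and are not asserted here.

# Sketch: see whether chat-control strings are encoded as single special tokens
# or split into pieces by the Hugging Face tokenizer (token IDs will vary).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)

for s in ["<|im_user|>", "<|im_middle|>", "<|im_end|>"]:
    ids = tok.encode(s, add_special_tokens=False)
    kind = "single special token" if len(ids) == 1 else "split into pieces"
    print(f"{s!r} -> {ids} ({kind})")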

@gabriellarson
Contributor Author

@gabriellarson Thanks again for your effort on this. I got slowed down but hope to have some test quants made using your work on ik_llama.cpp. Sorry I am not sure how to cherry-pick the commits so your name shows up right.

Also, it seems that while Kimi-K2 is close enough to DeepSeek to run, it has a different, unique-looking chat template, as pointed out to me by some folks on the BeaverAI Club discord:

<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>

Kimi-K2-Instruct/blob/main/tokenizer_config.json#L154

Not sure if it even needs to be added given the GGUF already has the template (pretty sure), but the recent Hunyuan-A13B did add code in that area fwiw. Just in case you do need something like that, it would go into llama-chat.cpp. Here is the draft chat template code I still need to test.

Thanks, will keep you posted how it goes tomorrow after I finally have some test quants 🤞

I'm getting stuck on adding the template to llm_chat_detect_template() in llama-chat.cpp:
} else if (tmpl_contains("???")) { return LLM_CHAT_TEMPLATE_KIMI_K2; }
I'm not sure what exactly to put here to uniquely identify the Kimi template.
@ubergarm

@CISC
Collaborator

CISC commented Jul 15, 2025

I'm getting stuck on adding the template to llm_chat_detect_template() in llama-chat.cpp } else if (tmpl_contains("???")) { return LLM_CHAT_TEMPLATE_KIMI_K2; } I'm not sure what exactly to put here to uniquely identify the kimi template

<|im_assistant|>assistant<|im_middle|> seems like a good one.

@gabriellarson gabriellarson requested a review from CISC July 15, 2025 16:27
Co-authored-by: Sigbjørn Skjæret <[email protected]>
@ubergarm

ubergarm commented Jul 15, 2025

@gabriellarson @CISC

I just got it working with this. My first attempt forgot to add_ass, and without it the model would give "empty replies". I've had some folks testing with good success now using this code, which would go into llama-chat.cpp as seen in this similar PR: ikawrakow/ik_llama.cpp#612. Tested with this model: ubergarm/Kimi-K2-Instruct-GGUF IQ2_KL, 345.687 GiB (2.892 BPW), Final estimate: PPL = 3.2741 +/- 0.01689 (upload complete in 30 minutes lol)

+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>" << message->content << "<|im_end|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>" << message->content << "<|im_end|>";
+            } else {
+                ss << "<|im_assistant|>assistant<|im_middle|>" << message->content << "<|im_end|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|im_assistant|>assistant<|im_middle|>";
+        }

@gabriellarson
Contributor Author

@ubergarm I'll add the if (add_ass){}.

I also include the tool role in mine; is the tool role not necessary?

@ubergarm

@gabriellarson

i also include the tool role in mine, is the tool role not necessary?

I wasn't sure myself honestly, but yes, yours does look correct to me given my understanding of the official template. Should be fine, but I don't have a quant to test it, and honestly I don't know how to use proper tool calling 😅

👈 chat template decoder script

output

$ python chat_template_tester.py moonshotai/Kimi-K2-Instruct
>> chat template <<
<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 2<|im_end|><|im_system|>tool<|im_middle|>## Return of \nsome kind of tool call maybe<|im_end|><|im_assistant|>assistant<|im_middle|>

python script

$ cat chat_template_tester.py
# uv pip install transformers jinja2
# (and sometimes also sentencepiece torch statsmodels, looking at you ERNIE4.5)
from transformers import AutoTokenizer
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("model", help="Name of Hugging Face LLM repo (org/model format)")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code = True)

chat = [
    {"role": "system", "content": "example system prompt"},
    {"role": "user", "content": "example user turn 1"},
    {"role": "assistant", "content": "example assistant turn 1"},
    {"role": "user", "content": "example user turn 2"},
    {"role": "assistant", "content": "example assistant turn 2"},
    {"role": "tool", "content": "some kind of tool call maybe"},
]

print(">> chat template <<")
print(tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False))
print(">> end of chat template <<")

@CISC CISC merged commit 4a4f426 into ggml-org:master Jul 15, 2025
1 check passed