diff --git a/README.md b/README.md
index 604bbdcc..63c418fc 100644
--- a/README.md
+++ b/README.md
@@ -33,12 +33,19 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri).
## What's New
-* [2025.05] AutoRound provides some recipes for **DeepSeek-R1-0528**, please refer to [DeepSeek-R1-0528-int2-mixed-sym-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-sym-inc), [DeepSeek-R1-0528-int4-sym-gptq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-gptq-inc-auto-round) and [DeepSeek-R1-0528-int4-asym-awq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-awq-inc-auto-round) for more details.
+
+* [2025/05] AutoRound now offers experimental support for the widely used **GGUF** format. We currently recommend
+  RTN mode (--iters 0) for all bit widths except 3. A more advanced algorithm tailored to specific configurations is
+  likely to be introduced in the upcoming release. Example models are available in the Intel organization on Hugging
+  Face, including
+  [Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0)
+  and [Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0).
+* [2025/05] AutoRound provides recipes for **DeepSeek-R1-0528**. Please refer
+  to [DeepSeek-R1-0528-int2-mixed-sym-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-sym-inc), [DeepSeek-R1-0528-int4-sym-gptq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-gptq-inc-auto-round)
+  and [DeepSeek-R1-0528-int4-asym-awq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-awq-inc-auto-round) for
+  more details.
* [2025/05] AutoRound has been integrated into **vLLM**. You can now run models in the AutoRound format directly with
vLLM versions later than v0.85.post1.
-* [2025/04] AutoRound provides some recipes for **Qwen3** series, please refer
- to [Qwen3-8B-sym-recipe](https://huggingface.co/Intel/Qwen3-8B-int4-AutoRound-inc) and [Qwen3-14B-sym-recipe](https://huggingface.co/Intel/Qwen3-14B-int4-AutoRound-inc) for
- more details.
* [2025/04] AutoRound has been integrated into **Transformers**. You can run models in the AutoRound format directly
with
Transformers versions later than 4.51.3.
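For the GGUF bullet above, here is a minimal sketch of producing a GGUF export in RTN mode through the Python API. The model name, output directory, and the `gguf:q4_k_s` format string are illustrative assumptions rather than values taken from this change.

```python
# Minimal sketch (model name, output path, and format string are assumptions).
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"  # hypothetical example model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# iters=0 selects RTN mode, the current recommendation for most GGUF bit widths.
autoround = AutoRound(model, tokenizer, iters=0)
autoround.quantize_and_save("./Qwen3-8B-gguf", format="gguf:q4_k_s")
```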
@@ -277,9 +284,12 @@ models. Besides, recently 3 bits may have some accuracy issues in Transformers.
**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely
adopted within the community, **only 4-bits quantization is supported**.
-**llmcompressor Format**: This format is for reusing llmcompressor format, **only INT8 W8A8 dynamic quantization is supported**.
+**llmcompressor Format**: This format is provided for reuse by llmcompressor; **only INT8 W8A8 dynamic quantization is
+supported**.
+
+**GGUF Format**: Experimental feature. This format is well-suited for CPU devices and is widely adopted by the
+community.
-**GGUF** Format: Experimental feature. This format is well-suited for CPU devices and is widely adopted by the community.
### Quantization Costs
Testing was conducted on the Nvidia A100 80G using the nightly version of PyTorch 2.6.0.dev20241029+cu124. Please note
@@ -340,18 +350,18 @@ Triton, but the final choice depends on factors such as bits, group_size, packin
backend may not always be the most suitable for certain devices. Please refer
to the following table for the details and specify the backend you want.
-| Name | Devices | Bits | Dtypes | Priority | Packing format | Requirements |
-|--------------------------------------|--------------|---------|-----------|----------|-----------------|----------------------------------|
-| ipex | cpu/xpu | 4 | BF16/FP16 | 5 | gptq_zp+-1/awq | intel-extension-for-pytorch |
+| Name | Devices | Bits | Dtypes | Priority | Packing format | Requirements |
+|--------------------------------------|--------------|---------|-----------|----------|-----------------|---------------------------------------|
+| ipex | cpu/xpu | 4 | BF16/FP16 | 5 | gptq_zp+-1/awq | intel-extension-for-pytorch |
| itrex                                | cpu          | 2,4,8   | BF16/FP16 | 1        | gptq_zp+-1      | intel-extension-for-transformers       |
-| marlin | cuda | 4,8 | BF16/FP16 | 6 | gptq/gptq_zp+-1 | gptqmodel |
-| exllamav2 or gptqmodel:exllamav2     | cuda         | 4       | BF16/FP16 | 5        | gptq            | gptqmodel                        |
-| exllamav2 or gptq:exllamav2          | cuda         | 4       | FP16      | 5        | gptq_zp+-1      | auto-gptq                        |
-| gptq:cuda | cuda | 2,3,4,8 | FP16 | 1 | gptq_zp+-1 | auto-gptq |
-| triton | xpu/cuda | 2,4,8 | BF16/FP16 | 2 | gptq/gptq_zp+-1 | auto-round |
-| awq | cuda | 4 | FP16 | 5 | awq | auto-awq |
-| hpu | hpu | 4 | BF16 | 0 | gptq/gptq_zp+-1 | auto-round |
-| torch | xpu/cpu/cuda | 2,3,4,8 | BF16/FP16 | 0 | gptq/gptq_zp+-1 | auto-round |
+| marlin | cuda | 4,8 | BF16/FP16 | 6 | gptq/gptq_zp+-1 | gptqmodel |
+| exllamav2 or gptqmodel:exllamav2     | cuda         | 4       | BF16/FP16 | 5        | gptq            | gptqmodel                              |
+| exllamav2 or gptq:exllamav2          | cuda         | 4       | FP16      | 5        | gptq_zp+-1      | auto-gptq                              |
+| gptq:cuda | cuda | 2,3,4,8 | FP16 | 1 | gptq_zp+-1 | auto-gptq |
+| triton | xpu/cuda | 2,4,8 | BF16/FP16 | 2 | gptq/gptq_zp+-1 | auto-round |
+| awq | cuda | 4 | FP16 | 5 | awq | auto-awq |
+| hpu | hpu | 4 | BF16 | 0 | gptq/gptq_zp+-1 | auto-round |
+| torch | xpu/cpu/cuda | 2,3,4,8 | BF16/FP16 | 0 | gptq/gptq_zp+-1 | auto-round |
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index ed38fc5f..51bffe93 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -716,6 +716,7 @@ def quantize_embedding_layer(self):
Returns:
-            bool: True if the quantization process completes without critical errors.
+            bool: True if at least one embedding layer was quantized, False otherwise.
"""
+ is_quantized = False
for name, module in self.model.named_modules():
# Skip non-Embedding modules or layers not in config
if not isinstance(module, torch.nn.Embedding) or name not in self.layer_config:
@@ -726,7 +727,7 @@ def quantize_embedding_layer(self):
# Skip layers that are not marked for quantization
if not check_to_quantized(config):
continue
-
+ is_quantized = True
config["scale_dtype"] = self.scale_dtype
dtype = config["data_type"]
@@ -779,7 +780,7 @@ def quantize_embedding_layer(self):
# Release memory
clear_memory()
- return True
+ return is_quantized
def quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> None:
"""Performs RTN quantization using input activation statistics (imatrix).
diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py
index 209909e3..9225dbae 100644
--- a/auto_round/calib_dataset.py
+++ b/auto_round/calib_dataset.py
@@ -16,7 +16,7 @@
import random
import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, load_dataset, concatenate_datasets
from datasets import Features, Sequence, Value
from torch.utils.data import DataLoader
import sys
@@ -38,7 +38,12 @@ def register_dataset(name):
"""
def register(dataset):
- CALIB_DATASETS[name] = dataset
+ if isinstance(name, list):
+ names = name
+ else:
+ names = [name]
+ for tmp_name in names:
+ CALIB_DATASETS[tmp_name] = dataset
return dataset
return register
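As a usage note for the alias support added above, the sketch below registers one loader under a full dataset name and a short alias; the dataset name and function are invented for illustration.

```python
# Hypothetical sketch: both names resolve to the same loader via CALIB_DATASETS,
# relying on the list handling added to register_dataset above.
@register_dataset(["my-org/my-calib-set", "my-calib-set"])
def get_my_calib_dataset(tokenizer, seqlen, dataset_name="my-org/my-calib-set", split=None,
                         seed=42, apply_chat_template=False, system_prompt=None):
    ...
```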
@@ -46,20 +51,22 @@ def register(dataset):
def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=None):
rendered_messages = []
- if system_prompt is None:
- system_prompt = "You are a helpful assistant."
+    # Do not inject a default system prompt: some models (e.g. DeepSeek)
+    # recommend against using one.
for text in samples:
- if system_prompt == "":
- message = [{"role": "user", "content": text}]
+ message = []
+ if system_prompt is not None and system_prompt != "":
+ message.append({"role": "system", "content": system_prompt})
+
+ if isinstance(text, list) and isinstance(text[0], dict):
+ message += text
else:
- message = [{"role": "system", "content": system_prompt},
- {"role": "user", "content": text}]
+ message.append({"role": "user", "content": text})
try:
chat_templated = tokenizer.apply_chat_template(
message,
tokenize=False,
add_generation_prompt=True,
-
)
except:
logger.warning(
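To make the branching above concrete, a small sketch of the two sample shapes the helper now accepts; the texts are invented examples.

```python
# Invented examples of the two accepted sample shapes.
plain_samples = ["Explain weight-only quantization in one sentence."]  # str: wrapped as a single user turn
chat_samples = [[
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]]  # list[dict]: used directly as the message list; a system prompt is prepended only when it is non-empty
```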
@@ -100,7 +107,7 @@ def default_tokenizer_function(examples):
return default_tokenizer_function
-@register_dataset("NeelNanda/pile-10k")
+@register_dataset(["NeelNanda/pile-10k", "pile-10k"])
def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -123,7 +130,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
try:
- calib_dataset = load_dataset(dataset_name, split=split)
+ calib_dataset = load_dataset("NeelNanda/pile-10k", split=split)
except Exception as e:
import ssl
error_message = str(e)
@@ -141,7 +148,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
return calib_dataset
-@register_dataset("swift/pile-val-backup")
+@register_dataset(["swift/pile-val-backup", "pile-val-backup"])
def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -168,14 +175,13 @@ def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup"
from modelscope import MsDataset # pylint: disable=E0401
calib_dataset = MsDataset.load('swift/pile-val-backup',
'default', split=split).to_iterable_dataset() # , use_streaming=True
- calib_dataset = calib_dataset.take(10000)
- calib_dataset = calib_dataset.shuffle(seed=seed)
+ calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
return calib_dataset
-@register_dataset("BAAI/CCI3-HQ")
+@register_dataset(["BAAI/CCI3-HQ", "CCI3-HQ"])
def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False,
system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -196,15 +202,14 @@ def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=No
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
- calib_dataset = load_dataset(dataset_name, split='train', streaming=True)
- calib_dataset = calib_dataset.take(10000)
- calib_dataset = calib_dataset.shuffle(seed=seed)
+ calib_dataset = load_dataset("BAAI/CCI3-HQ", split='train', streaming=True)
+ calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
return calib_dataset
-@register_dataset("codeparrot/github-code-clean")
+@register_dataset(["codeparrot/github-code-clean", "github-code-clean"])
def get_github_code_clean_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -243,19 +248,115 @@ def default_tokenizer_function(examples):
return default_tokenizer_function
- from datasets import load_dataset
-
tokenizer_function = get_default_tokenizer_function()
- calib_dataset = load_dataset(dataset_name, split='train', streaming=True, trust_remote_code=True)
- calib_dataset = calib_dataset.take(10000)
- calib_dataset = calib_dataset.shuffle(seed=seed)
+ dataset_mit = load_dataset("codeparrot/github-code-clean", "all-mit", split='train',
+ streaming=True, trust_remote_code=True).shuffle(seed=seed)
+ dataset_apache = load_dataset("codeparrot/github-code-clean", "all-apache-2.0", split='train',
+ streaming=True, trust_remote_code=True).shuffle(seed=seed)
+ calib_dataset = concatenate_datasets([dataset_mit, dataset_apache])
+    calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)  # TODO: shuffling a concatenation of streaming datasets may be unreliable
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
return calib_dataset
-@register_dataset("madao33/new-title-chinese")
+@register_dataset(["HuggingFaceH4/ultrachat_200k", "ultrachat_200k"])
+def get_ultrachat_dataset(
+ tokenizer,
+ seqlen,
+ dataset_name="HuggingFaceH4/ultrachat_200k",
+ split=None,
+ seed=42,
+ apply_chat_template=True,
+ system_prompt=None,
+):
+ if split is None:
+ split = "train_sft"
+ all_splits = ["train_sft", "test_sft", "train_gen", "test_gen"]
+ if split not in all_splits:
+ raise ValueError("split must be one of {} for ultrachat_200k ".format(all_splits))
+
+ dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split=split,
+ streaming=True, trust_remote_code=True)
+ dataset = dataset.shuffle(seed=seed).take(20000)
+
+ def is_instruct_tokenizer(tokenizer):
+ try:
+ out = tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}])
+ return bool(out and len(out) > 0)
+ except Exception:
+ return False
+
+ is_instruct = is_instruct_tokenizer(tokenizer)
+
+ if is_instruct and not apply_chat_template:
+ logger.info("Tokenizer looks like an instruct/chat model, but apply_chat_template=False. Setting to True.")
+ apply_chat_template = True
+ elif not is_instruct and apply_chat_template:
+ logger.info("Tokenizer is not an instruct/chat model, but apply_chat_template=True. Setting to False.")
+ apply_chat_template = False
+
+ def tokenize_example_batch(examples):
+ if not apply_chat_template:
+ texts = []
+ for message_list in examples["messages"]:
+ combined = "".join([msg["content"] for msg in message_list])
+ texts.append(combined)
+ return tokenizer(texts, truncation=True, max_length=seqlen)
+ else:
+ return apply_chat_template_to_samples(
+ examples["messages"], tokenizer, seqlen, system_prompt=system_prompt
+ )
+
+ dataset = dataset.map(tokenize_example_batch, batched=True)
+ return dataset
+
+
+@register_dataset(["openbmb/Ultra-FineWeb", "openbmb/Ultra-FineWeb"])
+def get_ultra_fineweb_dataset(
+ tokenizer,
+ seqlen,
+ dataset_name="openbmb/Ultra-FineWeb",
+ split=None,
+ seed=42,
+ apply_chat_template=True,
+ system_prompt=None,
+):
+ if split is not None:
+ if split not in ["en", "zh"]:
+ raise ValueError("split must be one of ['en', 'zh'] for Ultra-FineWeb dataset")
+ calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split=split,
+ streaming=True, trust_remote_code=True)
+ else:
+ calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split='en',
+ streaming=True, trust_remote_code=True)
+        # Loading and concatenating the 'zh' split as well was considered, but
+        # shuffling a concatenated streaming dataset is unreliable, so only a
+        # single split is loaded at a time.
+
+ calib_dataset = calib_dataset.shuffle(seed=seed).take(20000)
+
+ def get_default_tokenizer_function():
+ def default_tokenizer_function(examples):
+ if not apply_chat_template:
+ example = tokenizer(examples["content"], truncation=True, max_length=seqlen)
+ else:
+ example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen,
+ system_prompt=system_prompt)
+ return example
+
+ return default_tokenizer_function
+
+ tokenizer_function = get_default_tokenizer_function()
+
+ dataset = calib_dataset.map(tokenizer_function, batched=True)
+ return dataset
+
+
+@register_dataset(["madao33/new-title-chinese", "new-title-chinese"])
def get_new_chinese_title_dataset(
tokenizer,
seqlen,
@@ -307,7 +408,7 @@ def default_tokenizer_function(examples):
tokenizer_function = get_tokenizer_function()
- calib_dataset = load_dataset(dataset_name, split=split)
+ calib_dataset = load_dataset("madao33/new-title-chinese", split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
@@ -697,4 +798,3 @@ def collate_batch(batch):
calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
return calib_dataloader
-
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 45b98998..be55e658 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -39,7 +39,7 @@ def __init__(self):
self._support_format = (
"auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:gptqmodel",
"auto_round:auto_awq", "itrex", "itrex_xpu", "fake")
- self._gguf_format = tuple(GGUF_CONFIG.keys())
+ self._gguf_format = tuple(sorted(GGUF_CONFIG.keys()))
self._support_list = self._support_format + self._gguf_format
def __contains__(self, key):
diff --git a/docs/gguf_accuracy.md b/docs/gguf_accuracy.md
deleted file mode 100644
index 555e2e17..00000000
--- a/docs/gguf_accuracy.md
+++ /dev/null
@@ -1,43 +0,0 @@
-1 We evaluate all models using the `fake` format, as lm-eval reports inaccurate accuracy for real GGUF format
-
-
-lm-eval 0.48
-
-```bash
- lm-eval --model hf --model_args pretrained="./" --tasks mmlu,leaderboard_ifeval,leaderboard_mmlu_pro,gsm8k
- --batch_size 16
-```
-
-2 `lm-head` and `embedding` layers are not quantized in any of the following models.
-
-| Q4_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k |
-|---------------------------|------------|--------|----------|----------|--------|
-| Qwen2.5-7B-GGUF | 0.6366 | 0.7097 | 0.4385 | 0.61115 | 0.7870 |
-| Qwen2.5-7B-AutoRound | **0.6529** | 0.7137 | 0.4471 | 0.6373 | 0.8135 |
-| Llama-3.1-8B-GGUF | 0.5589 | 0.6609 | 0.3610 | 0.4949 | 0.7187 |
-| Llama-3.1-8B-AutoRound | **0.5666** | 0.6627 | 0.3648 | 0.49965 | 0.7392 |
-| Falcon3-7B-GGUF | 0.5179 | 0.6649 | 0.3607 | 0.3251 | 0.7210 |
-| Falcon3-7B-AutoRound | **0.5261** | 0.6706 | 0.3841 | 0.31445 | 0.7354 |
-| phi-4-GGUF | **0.5623** | 0.7648 | 0.5292 | 0.0590 | 0.8961 |
-| phi-4-AutoRound | 0.5588 | 0.7673 | 0.5239 | 0.05175 | 0.8923 |
-
-| Q3_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k |
-|---------------------------|------------|--------|----------|----------|--------|
-| Qwen2.5-7B-GGUF | 0.5939 | 0.6936 | 0.4062 | 0.57675 | 0.6990 |
-| Qwen2.5-7B-AutoRound | **0.6103** | 0.7002 | 0.4171 | 0.6194 | 0.7043 |
-| Llama-3.1-8B-GGUF | 0.4903 | 0.6050 | 0.3260 | 0.44265 | 0.5876 |
-| Llama-3.1-8B-AutoRound | **0.5511** | 0.6548 | 0.3533 | 0.4913 | 0.7051 |
-| Falcon3-7B-GGUF | 0.4905 | 0.6434 | 0.3439 | 0.2871 | 0.6876 |
-| Falcon3-7B-AutoRound | **0.5296** | 0.6520 | 0.3679 | 0.30745 | 0.7911 |
-| phi-4-GGUF | **0.5527** | 0.7590 | 0.5072 | 0.0802 | 0.8643 |
-| phi-4-AutoRound | 0.5523 | 0.7657 | 0.5124 | 0.0587 | 0.8726 |
-
-| Q2_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k |
-|---------------------------|------------|--------|----------|----------|--------|
-| Qwen2.5-7B-GGUF | 0.3942 | 0.5750 | 0.2701 | 0.4071 | 0.3245 |
-| Qwen2.5-7B-AutoRound | **0.5133** | 0.6384 | 0.3383 | 0.4714 | 0.6050 |
-| Falcon3-7B-GGUF | 0.1936 | 0.3491 | 0.1521 | 0.21615 | 0.0569 |
-| Falcon3-7B-AutoRound | **0.3817** | 0.5607 | 0.2625 | 0.28955 | 0.4139 |
-| phi-4-GGUF | 0.4438 | 0.6715 | 0.3807 | 0.0802 | 0.6429 |
-| phi-4-AutoRound | **0.5113** | 0.7107 | 0.4383 | 0.08675 | 0.8097 |
-
diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index 876a9dac..12081356 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -17,12 +17,13 @@ pip install auto-round
The [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) in huggingface is adopted as the default
calibration data and will be downloaded automatically from the datasets Hub. Other available datasets include:
-
- `swift/pile-val-backup` from modelscope for addressing HF network issue
- `BAAI/CCI3-HQ` for Chinese
- `codeparrot/github-code-clean` for code
+- `HuggingFaceH4/ultrachat_200k` for chat data
- `madao33/new-title-chinese` for Chinese
- `mbpp` for code
+- `openbmb/Ultra-FineWeb` for high-quality web text (en/zh splits, en by default)
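A minimal sketch of selecting one of the datasets listed above by name via the Python API's `dataset` argument; the model name and output directory are illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # illustrative small model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Calibrate with the chat-style dataset added in this release instead of the default pile-10k.
autoround = AutoRound(model, tokenizer, dataset="HuggingFaceH4/ultrachat_200k")
autoround.quantize_and_save("./opt-125m-w4g128")
```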
### Customized Dataset