support two more calib datasets and fix embedding layer bug #653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 13 commits · Jul 10, 2025
44 changes: 27 additions & 17 deletions README.md
@@ -33,12 +33,19 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri).
<div align="left">

## What's New
* [2025.05] AutoRound now offers experimental support for the widely used **GGUF** format. We currently recommend RTN mode (--iters 0) for all bit widths except 3-bit; a minimal sketch of this flow follows the list below. A more advanced algorithm tailored to specific configurations is likely to be introduced in an upcoming release. Example models are available on the Intel Hugging Face space, including [Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0) and [Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0).
* [2025.05] AutoRound provides recipes for **DeepSeek-R1-0528**; please refer to [DeepSeek-R1-0528-int2-mixed-sym-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-sym-inc), [DeepSeek-R1-0528-int4-sym-gptq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-gptq-inc-auto-round) and [DeepSeek-R1-0528-int4-asym-awq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-awq-inc-auto-round) for more details.
* [2025/05] AutoRound has been integrated into **vLLM**. You can now run models in the AutoRound format directly with vLLM versions later than v0.8.5.post1.
* [2025/04] AutoRound provides recipes for the **Qwen3** series; please refer to [Qwen3-8B-sym-recipe](https://huggingface.co/Intel/Qwen3-8B-int4-AutoRound-inc) and [Qwen3-14B-sym-recipe](https://huggingface.co/Intel/Qwen3-14B-int4-AutoRound-inc) for more details.
* [2025/04] AutoRound has been integrated into **Transformers**. You can run models in the AutoRound format directly with Transformers versions later than 4.51.3.
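
The sketch referenced in the GGUF bullet above is reproduced here for convenience. It is not part of this diff; the example model id and the `"gguf:q4_k_m"` format string are assumptions based on the project's other recipes, and only `iters=0` comes directly from the note above.

```python
# Minimal sketch of the RTN (iters=0) GGUF flow; names marked as assumptions
# are illustrative, not prescribed by the README.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"  # illustrative model, an assumption
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, iters=0)  # iters=0 -> RTN mode
autoround.quantize()
autoround.save_quantized("./tmp_autoround", format="gguf:q4_k_m")  # assumed GGUF type string
```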
@@ -277,9 +284,12 @@ models. Besides, recently 3 bits may have some accuracy issues in Transformers.
**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely
adopted within the community; **only 4-bit quantization is supported**.

**llmcompressor Format**: This format reuses the llmcompressor format; **only INT8 W8A8 dynamic quantization is supported**.

**GGUF** Format: Experimental feature. This format is well-suited for CPU devices and is widely adopted by the community.

### Quantization Costs

Testing was conducted on the Nvidia A100 80G using the nightly version of PyTorch 2.6.0.dev20241029+cu124. Please note
@@ -340,18 +350,18 @@ Triton, but the final choice depends on factors such as bits, group_size, packin
backend may not always be the most suitable for certain devices. Please refer
to the following table for the details and specify the backend you want.

| Name                                 | Devices      | Bits    | Dtypes    | Priority | Packing format  | Requirements                     |
|--------------------------------------|--------------|---------|-----------|----------|-----------------|----------------------------------|
| ipex                                 | cpu/xpu      | 4       | BF16/FP16 | 5        | gptq_zp+-1/awq  | intel-extension-for-pytorch      |
| itrex                                | cpu          | 2,4,8   | BF16/FP16 | 1        | gptq_zp+-1/awq  | intel-extension-for-transformers |
| marlin                               | cuda         | 4,8     | BF16/FP16 | 6        | gptq/gptq_zp+-1 | gptqmodel                        |
| exllamav2 or<br/>gptqmodel:exllamav2 | cuda         | 4       | BF16/FP16 | 5        | gptq            | gptqmodel                        |
| exllamav2 or<br/>gptq:exllamav2      | cuda         | 4       | FP16      | 5        | gptq_zp+-1      | auto-gptq                        |
| gptq:cuda                            | cuda         | 2,3,4,8 | FP16      | 1        | gptq_zp+-1      | auto-gptq                        |
| triton                               | xpu/cuda     | 2,4,8   | BF16/FP16 | 2        | gptq/gptq_zp+-1 | auto-round                       |
| awq                                  | cuda         | 4       | FP16      | 5        | awq             | auto-awq                         |
| hpu                                  | hpu          | 4       | BF16      | 0        | gptq/gptq_zp+-1 | auto-round                       |
| torch                                | xpu/cpu/cuda | 2,3,4,8 | BF16/FP16 | 0        | gptq/gptq_zp+-1 | auto-round                       |

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
```
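
The snippet above is cut off by the diff view. Below is a minimal loading sketch, not taken from this PR, assuming the backend is selected through `AutoRoundConfig(backend=...)` and using one of the recipe models linked in "What's New".

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig  # assumption: config class exposing a `backend` argument

model_id = "Intel/Qwen3-8B-int4-AutoRound-inc"   # recipe model linked above
quantization_config = AutoRoundConfig(backend="auto")  # or a backend name from the table, e.g. "ipex", "triton"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0], skip_special_tokens=True))
```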
5 changes: 3 additions & 2 deletions auto_round/autoround.py
@@ -716,6 +716,7 @@ def quantize_embedding_layer(self):
Returns:
bool: True if at least one embedding layer was quantized, False otherwise.
"""
is_quantized = False
for name, module in self.model.named_modules():
# Skip non-Embedding modules or layers not in config
if not isinstance(module, torch.nn.Embedding) or name not in self.layer_config:
@@ -726,7 +727,7 @@ def quantize_embedding_layer(self):
# Skip layers that are not marked for quantization
if not check_to_quantized(config):
continue

is_quantized = True
config["scale_dtype"] = self.scale_dtype
dtype = config["data_type"]

@@ -779,7 +780,7 @@ def quantize_embedding_layer(self):
# Release memory
clear_memory()

return True
return is_quantized

def quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> None:
"""Performs RTN quantization using input activation statistics (imatrix).
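Not part of the diff: a hypothetical caller sketch illustrating why the changed return value matters. Before the fix, `quantize_embedding_layer` returned True even when no embedding layer was quantized.

```python
import logging

logger = logging.getLogger(__name__)


def maybe_quantize_embeddings(autoround) -> bool:
    """Hypothetical caller (not the project's actual call site): with the fix,
    the return value is False when no embedding layer matched the layer config,
    so dependent steps can be skipped instead of assuming success."""
    if autoround.quantize_embedding_layer():
        logger.info("at least one embedding layer was quantized")
        return True
    logger.info("no embedding layer required quantization")
    return False
```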
154 changes: 127 additions & 27 deletions auto_round/calib_dataset.py
@@ -16,7 +16,7 @@
import random

import torch
from datasets import Dataset, IterableDataset
from datasets import Dataset, IterableDataset, load_dataset, concatenate_datasets
from datasets import Features, Sequence, Value
from torch.utils.data import DataLoader
import sys
@@ -38,28 +38,35 @@ def register_dataset(name):
"""

def register(dataset):
CALIB_DATASETS[name] = dataset
if isinstance(name, list):
names = name
else:
names = [name]
for tmp_name in names:
CALIB_DATASETS[tmp_name] = dataset
return dataset

return register
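
For illustration only (not from this PR): a small sketch of what the list-based registration above enables; a dataset registered under several names resolves to the same loader.

```python
# Sketch using names registered below: both the full repository id and the
# short alias point at the same loader function after this change.
from auto_round.calib_dataset import CALIB_DATASETS

loader = CALIB_DATASETS["pile-10k"]
assert loader is CALIB_DATASETS["NeelNanda/pile-10k"]
```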


def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=None):
rendered_messages = []
if system_prompt is None:
system_prompt = "You are a helpful assistant."
# if system_prompt is None: ## remove system prompt as models like deepseek don't recommend using it
# system_prompt = "You are a helpful assistant."
for text in samples:
if system_prompt == "":
message = [{"role": "user", "content": text}]
message = []
if system_prompt is not None and system_prompt != "":
message.append({"role": "system", "content": system_prompt})

if isinstance(text, list) and isinstance(text[0], dict):
message += text
else:
message = [{"role": "system", "content": system_prompt},
{"role": "user", "content": text}]
message.append({"role": "user", "content": text})
try:
chat_templated = tokenizer.apply_chat_template(
message,
tokenize=False,
add_generation_prompt=True,

)
except:
logger.warning(
@@ -100,7 +107,7 @@ def default_tokenizer_function(examples):
return default_tokenizer_function


@register_dataset("NeelNanda/pile-10k")
@register_dataset(["NeelNanda/pile-10k", "pile-10k"])
def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -123,7 +130,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
try:
calib_dataset = load_dataset(dataset_name, split=split)
calib_dataset = load_dataset("NeelNanda/pile-10k", split=split)
except Exception as e:
import ssl
error_message = str(e)
@@ -141,7 +148,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
return calib_dataset


@register_dataset("swift/pile-val-backup")
@register_dataset(["swift/pile-val-backup", "pile-val-backup"])
def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -168,14 +175,13 @@ def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup"
from modelscope import MsDataset # pylint: disable=E0401
calib_dataset = MsDataset.load('swift/pile-val-backup',
'default', split=split).to_iterable_dataset() # , use_streaming=True
calib_dataset = calib_dataset.take(10000)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

return calib_dataset


@register_dataset("BAAI/CCI3-HQ")
@register_dataset(["BAAI/CCI3-HQ", "CCI3-HQ"])
def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False,
system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
Expand All @@ -196,15 +202,14 @@ def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=No
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
system_prompt=system_prompt)

calib_dataset = load_dataset(dataset_name, split='train', streaming=True)
calib_dataset = calib_dataset.take(10000)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = load_dataset("BAAI/CCI3-HQ", split='train', streaming=True)
calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

return calib_dataset


@register_dataset("codeparrot/github-code-clean")
@register_dataset(["codeparrot/github-code-clean", "github-code-clean"])
def get_github_code_clean_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42,
apply_chat_template=False, system_prompt=None):
"""Returns a dataloader for the specified dataset and split.
@@ -243,19 +248,115 @@ def default_tokenizer_function(examples):

return default_tokenizer_function

from datasets import load_dataset

tokenizer_function = get_default_tokenizer_function()

calib_dataset = load_dataset(dataset_name, split='train', streaming=True, trust_remote_code=True)
calib_dataset = calib_dataset.take(10000)
calib_dataset = calib_dataset.shuffle(seed=seed)
dataset_mit = load_dataset("codeparrot/github-code-clean", "all-mit", split='train',
streaming=True, trust_remote_code=True).shuffle(seed=seed)
dataset_apache = load_dataset("codeparrot/github-code-clean", "all-apache-2.0", split='train',
streaming=True, trust_remote_code=True).shuffle(seed=seed)
calib_dataset = concatenate_datasets([dataset_mit, dataset_apache])
calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)  # TODO: shuffling a concatenated streaming dataset may have bugs
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

return calib_dataset
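
Not adopted by this PR: given the TODO about shuffling a concatenated streaming dataset, one possible alternative is `datasets.interleave_datasets`, which mixes the two license subsets directly. A sketch:

```python
from datasets import load_dataset, interleave_datasets

# Sketch of an alternative to concatenate-then-shuffle for streaming datasets:
# interleave the two license subsets, then take the calibration slice.
dataset_mit = load_dataset("codeparrot/github-code-clean", "all-mit",
                           split="train", streaming=True, trust_remote_code=True)
dataset_apache = load_dataset("codeparrot/github-code-clean", "all-apache-2.0",
                              split="train", streaming=True, trust_remote_code=True)
mixed = interleave_datasets([dataset_mit, dataset_apache], probabilities=[0.5, 0.5], seed=42)
calib_dataset = mixed.shuffle(seed=42, buffer_size=10_000).take(10000)
```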


@register_dataset("madao33/new-title-chinese")
@register_dataset(["HuggingFaceH4/ultrachat_200k", "ultrachat_200k"])
def get_ultrachat_dataset(
tokenizer,
seqlen,
dataset_name="HuggingFaceH4/ultrachat_200k",
split=None,
seed=42,
apply_chat_template=True,
system_prompt=None,
):
if split is None:
split = "train_sft"
all_splits = ["train_sft", "test_sft", "train_gen", "test_gen"]
if split not in all_splits:
raise ValueError("split must be one of {} for ultrachat_200k ".format(all_splits))

dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split=split,
streaming=True, trust_remote_code=True)
dataset = dataset.shuffle(seed=seed).take(20000)

def is_instruct_tokenizer(tokenizer):
try:
out = tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}])
return bool(out and len(out) > 0)
except Exception:
return False

is_instruct = is_instruct_tokenizer(tokenizer)

if is_instruct and not apply_chat_template:
logger.info("Tokenizer looks like an instruct/chat model, but apply_chat_template=False. Setting to True.")
apply_chat_template = True
elif not is_instruct and apply_chat_template:
logger.info("Tokenizer is not an instruct/chat model, but apply_chat_template=True. Setting to False.")
apply_chat_template = False

def tokenize_example_batch(examples):
if not apply_chat_template:
texts = []
for message_list in examples["messages"]:
combined = "".join([msg["content"] for msg in message_list])
texts.append(combined)
return tokenizer(texts, truncation=True, max_length=seqlen)
else:
return apply_chat_template_to_samples(
examples["messages"], tokenizer, seqlen, system_prompt=system_prompt
)

dataset = dataset.map(tokenize_example_batch, batched=True)
return dataset


@register_dataset(["openbmb/Ultra-FineWeb", "openbmb/Ultra-FineWeb"])
def get_ultrafinweb_dataset(
tokenizer,
seqlen,
dataset_name="openbmb/Ultra-FineWeb",
split=None,
seed=42,
apply_chat_template=True,
system_prompt=None,
):
if split is not None:
if split not in ["en", "zh"]:
raise ValueError("split must be one of ['en', 'zh'] for Ultra-FineWeb dataset")
calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split=split,
streaming=True, trust_remote_code=True)
else:
calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split='en',
streaming=True, trust_remote_code=True)
# dataset_ch = load_dataset("openbmb/Ultra-FineWeb", split='zh',
# streaming=True, trust_remote_code=True).shuffle(seed=seed).take(2000)

# calib_dataset = concatenate_datasets([dataset_en, dataset_ch])  # concatenated dataset could not be shuffled


calib_dataset = calib_dataset.shuffle(seed=seed).take(20000)

def get_default_tokenizer_function():
def default_tokenizer_function(examples):
if not apply_chat_template:
example = tokenizer(examples["content"], truncation=True, max_length=seqlen)
else:
example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen,
system_prompt=system_prompt)
return example

return default_tokenizer_function

tokenizer_function = get_default_tokenizer_function()

dataset = calib_dataset.map(tokenizer_function, batched=True)
return dataset
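
A short usage sketch, not part of the diff, showing how the two new calibration datasets might be exercised directly through the registry; the tokenizer id is only an example.

```python
from transformers import AutoTokenizer
from auto_round.calib_dataset import CALIB_DATASETS

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # example model, an assumption

# Both new loaders are reachable by short alias or full repository id.
ultrachat = CALIB_DATASETS["ultrachat_200k"](tokenizer, seqlen=2048, split="train_sft", seed=42)
finweb = CALIB_DATASETS["openbmb/Ultra-FineWeb"](tokenizer, seqlen=2048, split="en", seed=42)

first = next(iter(ultrachat))
print(list(first.keys()))  # expect tokenized fields such as "input_ids"
```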


@register_dataset(["madao33/new-title-chinese", "new-title-chinese"])
def get_new_chinese_title_dataset(
tokenizer,
seqlen,
@@ -307,7 +408,7 @@ def default_tokenizer_function(examples):

tokenizer_function = get_tokenizer_function()

calib_dataset = load_dataset(dataset_name, split=split)
calib_dataset = load_dataset("madao33/new-title-chinese", split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

@@ -697,4 +798,3 @@ def collate_batch(batch):

calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
return calib_dataloader

2 changes: 1 addition & 1 deletion auto_round/utils.py
@@ -39,7 +39,7 @@ def __init__(self):
self._support_format = (
"auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:gptqmodel",
"auto_round:auto_awq", "itrex", "itrex_xpu", "fake")
self._gguf_format = tuple(GGUF_CONFIG.keys())
self._gguf_format = tuple(sorted(GGUF_CONFIG.keys()))
self._support_list = self._support_format + self._gguf_format

def __contains__(self, key):