diff --git a/README.md b/README.md index 604bbdcc..63c418fc 100644 --- a/README.md +++ b/README.md @@ -33,12 +33,19 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri).
## What's New

-* [2025.05] AutoRound provides some recipes for **DeepSeek-R1-0528**, please refer to [DeepSeek-R1-0528-int2-mixed-sym-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-sym-inc), [DeepSeek-R1-0528-int4-sym-gptq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-gptq-inc-auto-round) and [DeepSeek-R1-0528-int4-asym-awq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-awq-inc-auto-round) for more details.
+
+* [2025.05] AutoRound now offers experimental support for the widely used **GGUF** format. We currently recommend
+  RTN mode (`--iters 0`) for all bit widths except 3-bit. A more advanced algorithm tailored to specific configurations
+  is likely to be introduced in an upcoming release. Example models are available on the Intel Hugging Face space,
+  including [Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound-inc-v0)
+  and [Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound-inc-v0).
+* [2025.05] AutoRound provides recipes for **DeepSeek-R1-0528**; please refer
+  to [DeepSeek-R1-0528-int2-mixed-sym-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-sym-inc), [DeepSeek-R1-0528-int4-sym-gptq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-gptq-inc-auto-round)
+  and [DeepSeek-R1-0528-int4-asym-awq-inc](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-awq-inc-auto-round) for
+  more details.
* [2025/05] AutoRound has been integrated into **vLLM**. You can now run models in the AutoRound format directly with
  vLLM versions later than v0.85.post1.
-* [2025/04] AutoRound provides some recipes for **Qwen3** series, please refer
-  to [Qwen3-8B-sym-recipe](https://huggingface.co/Intel/Qwen3-8B-int4-AutoRound-inc) and [Qwen3-14B-sym-recipe](https://huggingface.co/Intel/Qwen3-14B-int4-AutoRound-inc) for
-  more details.
* [2025/04] AutoRound has been integrated into **Transformers**. You can run models in the AutoRound format directly
  with Transformers versions later than 4.51.3.

@@ -277,9 +284,12 @@ models. Besides, recently 3 bits may have some accuracy issues in Transformers.

**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely
adopted within the community, **only 4-bits quantization is supported**.

-**llmcompressor Format**: This format is for reusing llmcompressor format, **only INT8 W8A8 dynamic quantization is supported**.
+**llmcompressor Format**: This format reuses the llmcompressor format; **only INT8 W8A8 dynamic quantization is
+supported**.
+
+**GGUF Format**: Experimental feature. This format is well-suited for CPU devices and is widely adopted by the
+community.

-**GGUF** Format: Experimental feature. This format is well-suited for CPU devices and is widely adopted by the community.

### Quantization Costs

Testing was conducted on the Nvidia A100 80G using the nightly version of PyTorch 2.6.0.dev20241029+cu124. Please note
@@ -340,18 +350,18 @@ Triton, but the final choice depends on factors such as bits, group_size, packin
backend may not always be the most suitable for certain devices. Please refer to the following table for the details
and specify the backend you want.

-| Name | Devices | Bits | Dtypes | Priority | Packing format | Requirements |
-|--------------------------------------|--------------|---------|-----------|----------|-----------------|----------------------------------|
-| ipex | cpu/xpu | 4 | BF16/FP16 | 5 | gptq_zp+-1/awq | intel-extension-for-pytorch |
+| Name | Devices | Bits | Dtypes | Priority | Packing format | Requirements |
+|--------------------------------------|--------------|---------|-----------|----------|-----------------|---------------------------------------|
+| ipex | cpu/xpu | 4 | BF16/FP16 | 5 | gptq_zp+-1/awq | intel-extension-for-pytorch |
| itrex | cpu | 2,4,8 | BF16/FP16 | 1 | gptq_zp+-1/awq | intel-extension-for-transformers |
-| marlin | cuda | 4,8 | BF16/FP16 | 6 | gptq/gptq_zp+-1 | gptqmodel |
-| exllamav2 or gptqmodel:exllamav2 | cuda | 4 | BF16/FP16 | 5 | gptq | gptqmodel |
-| exllamav2 or gptq:exllamav2 | cuda | 4 | FP16 | 5 | gptq_zp+-1 | auto-gptq |
-| gptq:cuda | cuda | 2,3,4,8 | FP16 | 1 | gptq_zp+-1 | auto-gptq |
-| triton | xpu/cuda | 2,4,8 | BF16/FP16 | 2 | gptq/gptq_zp+-1 | auto-round |
-| awq | cuda | 4 | FP16 | 5 | awq | auto-awq |
-| hpu | hpu | 4 | BF16 | 0 | gptq/gptq_zp+-1 | auto-round |
-| torch | xpu/cpu/cuda | 2,3,4,8 | BF16/FP16 | 0 | gptq/gptq_zp+-1 | auto-round |
+| marlin | cuda | 4,8 | BF16/FP16 | 6 | gptq/gptq_zp+-1 | gptqmodel |
+| exllamav2 or gptqmodel:exllamav2 | cuda | 4 | BF16/FP16 | 5 | gptq | gptqmodel |
+| exllamav2 or
gptq:exllamav2 | cuda | 4 | FP16 | 5 | gptq_zp+-1 | auto-gptq | +| gptq:cuda | cuda | 2,3,4,8 | FP16 | 1 | gptq_zp+-1 | auto-gptq | +| triton | xpu/cuda | 2,4,8 | BF16/FP16 | 2 | gptq/gptq_zp+-1 | auto-round | +| awq | cuda | 4 | FP16 | 5 | awq | auto-awq | +| hpu | hpu | 4 | BF16 | 0 | gptq/gptq_zp+-1 | auto-round | +| torch | xpu/cpu/cuda | 2,3,4,8 | BF16/FP16 | 0 | gptq/gptq_zp+-1 | auto-round | ```python from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/auto_round/autoround.py b/auto_round/autoround.py index ed38fc5f..51bffe93 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -716,6 +716,7 @@ def quantize_embedding_layer(self): Returns: bool: True if the quantization process completes without critical errors. """ + is_quantized = False for name, module in self.model.named_modules(): # Skip non-Embedding modules or layers not in config if not isinstance(module, torch.nn.Embedding) or name not in self.layer_config: @@ -726,7 +727,7 @@ def quantize_embedding_layer(self): # Skip layers that are not marked for quantization if not check_to_quantized(config): continue - + is_quantized = True config["scale_dtype"] = self.scale_dtype dtype = config["data_type"] @@ -779,7 +780,7 @@ def quantize_embedding_layer(self): # Release memory clear_memory() - return True + return is_quantized def quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> None: """Performs RTN quantization using input activation statistics (imatrix). diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index 209909e3..9225dbae 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -16,7 +16,7 @@ import random import torch -from datasets import Dataset, IterableDataset +from datasets import Dataset, IterableDataset, load_dataset, concatenate_datasets from datasets import Features, Sequence, Value from torch.utils.data import DataLoader import sys @@ -38,7 +38,12 @@ def register_dataset(name): """ def register(dataset): - CALIB_DATASETS[name] = dataset + if isinstance(name, list): + names = name + else: + names = [name] + for tmp_name in names: + CALIB_DATASETS[tmp_name] = dataset return dataset return register @@ -46,20 +51,22 @@ def register(dataset): def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=None): rendered_messages = [] - if system_prompt is None: - system_prompt = "You are a helpful assistant." + # if system_prompt is None: ## remove system prompt as models like deepseek don't recommend using it + # system_prompt = "You are a helpful assistant." 
    for text in samples:
-        if system_prompt == "":
-            message = [{"role": "user", "content": text}]
+        message = []
+        if system_prompt is not None and system_prompt != "":
+            message.append({"role": "system", "content": system_prompt})
+
+        if isinstance(text, list) and isinstance(text[0], dict):
+            message += text
         else:
-            message = [{"role": "system", "content": system_prompt},
-                       {"role": "user", "content": text}]
+            message.append({"role": "user", "content": text})
         try:
             chat_templated = tokenizer.apply_chat_template(
                 message,
                 tokenize=False,
                 add_generation_prompt=True,
             )
         except:
             logger.warning(
@@ -100,7 +107,7 @@ def default_tokenizer_function(examples):

        return default_tokenizer_function


-@register_dataset("NeelNanda/pile-10k")
+@register_dataset(["NeelNanda/pile-10k", "pile-10k"])
 def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, apply_chat_template=False,
                      system_prompt=None):
     """Returns a dataloader for the specified dataset and split.
@@ -123,7 +130,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
     tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
                                                 system_prompt=system_prompt)
     try:
-        calib_dataset = load_dataset(dataset_name, split=split)
+        calib_dataset = load_dataset("NeelNanda/pile-10k", split=split)
     except Exception as e:
         import ssl
         error_message = str(e)
@@ -141,7 +148,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split

     return calib_dataset


-@register_dataset("swift/pile-val-backup")
+@register_dataset(["swift/pile-val-backup", "pile-val-backup"])
 def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup", split=None, seed=42,
                          apply_chat_template=False, system_prompt=None):
     """Returns a dataloader for the specified dataset and split.
@@ -168,14 +175,13 @@ def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup"
         from modelscope import MsDataset  # pylint: disable=E0401
         calib_dataset = MsDataset.load('swift/pile-val-backup', 'default', split=split).to_iterable_dataset()  # , use_streaming=True
-        calib_dataset = calib_dataset.take(10000)
-        calib_dataset = calib_dataset.shuffle(seed=seed)
+        calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
     calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

     return calib_dataset


-@register_dataset("BAAI/CCI3-HQ")
+@register_dataset(["BAAI/CCI3-HQ", "CCI3-HQ"])
 def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False,
                         system_prompt=None):
     """Returns a dataloader for the specified dataset and split.
@@ -196,15 +202,14 @@ def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=No
     tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template,
                                                 system_prompt=system_prompt)

-    calib_dataset = load_dataset(dataset_name, split='train', streaming=True)
-    calib_dataset = calib_dataset.take(10000)
-    calib_dataset = calib_dataset.shuffle(seed=seed)
+    calib_dataset = load_dataset("BAAI/CCI3-HQ", split='train', streaming=True)
+    calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)
     calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

     return calib_dataset


-@register_dataset("codeparrot/github-code-clean")
+@register_dataset(["codeparrot/github-code-clean", "github-code-clean"])
 def get_github_code_clean_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42,
                                   apply_chat_template=False, system_prompt=None):
     """Returns a dataloader for the specified dataset and split.
@@ -243,19 +248,115 @@ def default_tokenizer_function(examples):

         return default_tokenizer_function

-    from datasets import load_dataset
     tokenizer_function = get_default_tokenizer_function()
-    calib_dataset = load_dataset(dataset_name, split='train', streaming=True, trust_remote_code=True)
-    calib_dataset = calib_dataset.take(10000)
-    calib_dataset = calib_dataset.shuffle(seed=seed)
+    dataset_mit = load_dataset("codeparrot/github-code-clean", "all-mit", split='train',
+                               streaming=True, trust_remote_code=True).shuffle(seed=seed)
+    dataset_apache = load_dataset("codeparrot/github-code-clean", "all-apache-2.0", split='train',
+                                  streaming=True, trust_remote_code=True).shuffle(seed=seed)
+    calib_dataset = concatenate_datasets([dataset_mit, dataset_apache])
+    calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)  ## TODO: shuffling a concatenated streaming dataset may have bugs
     calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

     return calib_dataset


-@register_dataset("madao33/new-title-chinese")
+@register_dataset(["HuggingFaceH4/ultrachat_200k", "ultrachat_200k"])
+def get_ultrachat_dataset(
+        tokenizer,
+        seqlen,
+        dataset_name="HuggingFaceH4/ultrachat_200k",
+        split=None,
+        seed=42,
+        apply_chat_template=True,
+        system_prompt=None,
+):
+    if split is None:
+        split = "train_sft"
+    all_splits = ["train_sft", "test_sft", "train_gen", "test_gen"]
+    if split not in all_splits:
+        raise ValueError("split must be one of {} for ultrachat_200k".format(all_splits))
+
+    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split=split,
+                           streaming=True, trust_remote_code=True)
+    dataset = dataset.shuffle(seed=seed).take(20000)
+
+    def is_instruct_tokenizer(tokenizer):
+        try:
+            out = tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}])
+            return bool(out and len(out) > 0)
+        except Exception:
+            return False
+
+    is_instruct = is_instruct_tokenizer(tokenizer)
+
+    if is_instruct and not apply_chat_template:
+        logger.info("Tokenizer looks like an instruct/chat model, but apply_chat_template=False. Setting to True.")
+        apply_chat_template = True
+    elif not is_instruct and apply_chat_template:
+        logger.info("Tokenizer is not an instruct/chat model, but apply_chat_template=True. Setting to False.")
+        apply_chat_template = False
+
+    def tokenize_example_batch(examples):
+        if not apply_chat_template:
+            texts = []
+            for message_list in examples["messages"]:
+                combined = "".join([msg["content"] for msg in message_list])
+                texts.append(combined)
+            return tokenizer(texts, truncation=True, max_length=seqlen)
+        else:
+            return apply_chat_template_to_samples(
+                examples["messages"], tokenizer, seqlen, system_prompt=system_prompt
+            )
+
+    dataset = dataset.map(tokenize_example_batch, batched=True)
+    return dataset
+
+
+@register_dataset(["openbmb/Ultra-FineWeb", "Ultra-FineWeb"])
+def get_ultrafineweb_dataset(
+        tokenizer,
+        seqlen,
+        dataset_name="openbmb/Ultra-FineWeb",
+        split=None,
+        seed=42,
+        apply_chat_template=True,
+        system_prompt=None,
+):
+    if split is not None:
+        if split not in ["en", "zh"]:
+            raise ValueError("split must be one of ['en', 'zh'] for the Ultra-FineWeb dataset")
+        calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split=split,
+                                     streaming=True, trust_remote_code=True)
+    else:
+        calib_dataset = load_dataset("openbmb/Ultra-FineWeb", split='en',
+                                     streaming=True, trust_remote_code=True)
+        # dataset_ch = load_dataset("openbmb/Ultra-FineWeb", split='zh',
+        #                           streaming=True, trust_remote_code=True).shuffle(seed=seed).take(2000)
+
+        # calib_dataset = concatenate_datasets([dataset_en, dataset_ch])  ## concatenated datasets could not be shuffled
+
+    calib_dataset = calib_dataset.shuffle(seed=seed).take(20000)
+
+    def get_default_tokenizer_function():
+        def default_tokenizer_function(examples):
+            if not apply_chat_template:
+                example = tokenizer(examples["content"], truncation=True, max_length=seqlen)
+            else:
+                example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen,
+                                                         system_prompt=system_prompt)
+            return example

+        return default_tokenizer_function
+
+    tokenizer_function = get_default_tokenizer_function()
+
+    dataset = calib_dataset.map(tokenizer_function, batched=True)
+    return dataset
+
+
+@register_dataset(["madao33/new-title-chinese", "new-title-chinese"])
 def get_new_chinese_title_dataset(
         tokenizer,
         seqlen,
@@ -307,7 +408,7 @@ def default_tokenizer_function(examples):

     tokenizer_function = get_tokenizer_function()

-    calib_dataset = load_dataset(dataset_name, split=split)
+    calib_dataset = load_dataset("madao33/new-title-chinese", split=split)
     calib_dataset = calib_dataset.shuffle(seed=seed)
     calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

@@ -697,4 +798,3 @@ def collate_batch(batch):

     calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
     return calib_dataloader
-
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 45b98998..be55e658 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -39,7 +39,7 @@ def __init__(self):
         self._support_format = (
             "auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:gptqmodel",
             "auto_round:auto_awq", "itrex", "itrex_xpu", "fake")
-        self._gguf_format = tuple(GGUF_CONFIG.keys())
+        self._gguf_format = tuple(sorted(GGUF_CONFIG.keys()))
         self._support_list = self._support_format + self._gguf_format

     def __contains__(self, key):
diff --git a/docs/gguf_accuracy.md b/docs/gguf_accuracy.md
deleted file mode 100644
index 555e2e17..00000000
--- a/docs/gguf_accuracy.md
+++ /dev/null
@@ -1,43 +0,0 @@
-1 We evaluate all models using the `fake` format, as lm-eval reports inaccurate accuracy for real GGUF format
-
-
-lm-eval 0.48
-
-```bash
- lm-eval --model hf --model_args pretrained="./" --tasks 
mmlu,leaderboard_ifeval,leaderboard_mmlu_pro,gsm8k - --batch_size 16 -``` - -2 `lm-head` and `embedding` layers are not quantized in any of the following models. - -| Q4_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k | -|---------------------------|------------|--------|----------|----------|--------| -| Qwen2.5-7B-GGUF | 0.6366 | 0.7097 | 0.4385 | 0.61115 | 0.7870 | -| Qwen2.5-7B-AutoRound | **0.6529** | 0.7137 | 0.4471 | 0.6373 | 0.8135 | -| Llama-3.1-8B-GGUF | 0.5589 | 0.6609 | 0.3610 | 0.4949 | 0.7187 | -| Llama-3.1-8B-AutoRound | **0.5666** | 0.6627 | 0.3648 | 0.49965 | 0.7392 | -| Falcon3-7B-GGUF | 0.5179 | 0.6649 | 0.3607 | 0.3251 | 0.7210 | -| Falcon3-7B-AutoRound | **0.5261** | 0.6706 | 0.3841 | 0.31445 | 0.7354 | -| phi-4-GGUF | **0.5623** | 0.7648 | 0.5292 | 0.0590 | 0.8961 | -| phi-4-AutoRound | 0.5588 | 0.7673 | 0.5239 | 0.05175 | 0.8923 | - -| Q3_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k | -|---------------------------|------------|--------|----------|----------|--------| -| Qwen2.5-7B-GGUF | 0.5939 | 0.6936 | 0.4062 | 0.57675 | 0.6990 | -| Qwen2.5-7B-AutoRound | **0.6103** | 0.7002 | 0.4171 | 0.6194 | 0.7043 | -| Llama-3.1-8B-GGUF | 0.4903 | 0.6050 | 0.3260 | 0.44265 | 0.5876 | -| Llama-3.1-8B-AutoRound | **0.5511** | 0.6548 | 0.3533 | 0.4913 | 0.7051 | -| Falcon3-7B-GGUF | 0.4905 | 0.6434 | 0.3439 | 0.2871 | 0.6876 | -| Falcon3-7B-AutoRound | **0.5296** | 0.6520 | 0.3679 | 0.30745 | 0.7911 | -| phi-4-GGUF | **0.5527** | 0.7590 | 0.5072 | 0.0802 | 0.8643 | -| phi-4-AutoRound | 0.5523 | 0.7657 | 0.5124 | 0.0587 | 0.8726 | - -| Q2_K_S | Avg. | mmlu | mmlu_pro | if_eval | gsm8k | -|---------------------------|------------|--------|----------|----------|--------| -| Qwen2.5-7B-GGUF | 0.3942 | 0.5750 | 0.2701 | 0.4071 | 0.3245 | -| Qwen2.5-7B-AutoRound | **0.5133** | 0.6384 | 0.3383 | 0.4714 | 0.6050 | -| Falcon3-7B-GGUF | 0.1936 | 0.3491 | 0.1521 | 0.21615 | 0.0569 | -| Falcon3-7B-AutoRound | **0.3817** | 0.5607 | 0.2625 | 0.28955 | 0.4139 | -| phi-4-GGUF | 0.4438 | 0.6715 | 0.3807 | 0.0802 | 0.6429 | -| phi-4-AutoRound | **0.5113** | 0.7107 | 0.4383 | 0.08675 | 0.8097 | - diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 876a9dac..12081356 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -17,12 +17,13 @@ pip install auto-round The [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) in huggingface is adopted as the default calibration data and will be downloaded automatically from the datasets Hub. Other available datasets include: - - `swift/pile-val-backup` from modelscope for addressing HF network issue - `BAAI/CCI3-HQ` for Chinese - `codeparrot/github-code-clean` for code +- `HuggingFaceH4/ultrachat_200k` for chat data - `madao33/new-title-chinese` for Chinese - `mbpp` for code +- `openbmb/Ultra-FineWeb` ### Customized Dataset
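A minimal sketch of how a customized calibration dataset could be registered with the list-based `register_dataset` decorator introduced in this diff. It is illustrative only and not part of the PR: the dataset id `my-org/my-corpus`, its short alias, and the `text` column are hypothetical, while the decorator, loader signature, and `datasets` calls mirror the built-in loaders shown above.

```python
from datasets import load_dataset

from auto_round.calib_dataset import register_dataset


@register_dataset(["my-org/my-corpus", "my-corpus"])  # hypothetical Hub id plus a short alias
def get_my_corpus_dataset(tokenizer, seqlen, dataset_name="my-org/my-corpus", split=None, seed=42,
                          apply_chat_template=False, system_prompt=None):
    """Stream a hypothetical text corpus and tokenize it for calibration."""
    calib_dataset = load_dataset(dataset_name, split=split or "train", streaming=True)
    calib_dataset = calib_dataset.shuffle(seed=seed).take(10000)

    def tokenizer_function(examples):
        # Assumes the corpus exposes raw text under a "text" column.
        return tokenizer(examples["text"], truncation=True, max_length=seqlen)

    return calib_dataset.map(tokenizer_function, batched=True)
```

Once registered, either name should be usable wherever the built-in dataset names such as `pile-10k` or `ultrachat_200k` are accepted.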
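The What's New entry above recommends RTN mode (`--iters 0`) for most GGUF bit widths. The sketch below shows one possible Python flow under that recommendation, assuming the `AutoRound(...)` / `quantize()` / `save_quantized()` API used elsewhere in the README; the model id and the `"gguf:q4_k_m"` format string are assumptions and should be checked against the GGUF configurations exposed by the installed auto-round version.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"  # illustrative model id
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# iters=0 selects RTN mode, the setting currently recommended for GGUF except 3-bit;
# bits and group size generally follow the chosen GGUF type.
autoround = AutoRound(model, tokenizer, iters=0)
autoround.quantize()

# "gguf:q4_k_m" is an assumed format string; verify the supported GGUF keys first.
autoround.save_quantized("./Qwen3-8B-gguf", format="gguf:q4_k_m")
```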