
Commit 37341f5

refine inference step 2 (#498)
1 parent ea4d843 commit 37341f5

13 files changed (+459, -132 lines)

README.md (+3, -8)

@@ -73,7 +73,7 @@ pip install auto-round-lib

 ## Model Quantization

-### Basic Usage (Gaudi2/CPU/GPU)
+### Basic Usage (Gaudi2/CPU/XPU/GPU)

 A user guide detailing the full list of supported arguments is provided by calling ```auto-round -h``` on the terminal.
 Set the format you want in `format` and
@@ -268,7 +268,7 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True)

 ### Export Formats
 **AutoRound Format**: This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision
-inference. **[2,4] bits are supported**. However, it has not yet gained widespread community adoption.
+inference. **[2,3,4,8] bits are supported**. However, it has not yet gained widespread community adoption.

 **AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
 community, **[2,3,4,8] bits are supported**. However, **the
@@ -324,14 +324,9 @@ in [Gaudi Guide](https://docs.habana.ai/en/latest/).
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from auto_round import AutoRoundConfig

-backend = "auto" ##cpu, hpu, cuda
-quantization_config = AutoRoundConfig(
-    backend=backend
-)
 quantized_model_path = "./tmp_autoround"
 model = AutoModelForCausalLM.from_pretrained(quantized_model_path,
-                                             device_map=backend.split(':')[0],
-                                             quantization_config=quantization_config)
+                                             device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
 text = "There is a girl who likes adventure,"
 inputs = tokenizer(text, return_tensors="pt").to(model.device)

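With this change the README's inference snippet no longer builds an `AutoRoundConfig` or picks a backend by hand; `device_map="auto"` is enough and the backend is inferred at load time. A minimal end-to-end sketch of the new flow (the generation settings are illustrative, not from the README):

```python
# Minimal sketch of the simplified loading path; "./tmp_autoround" is assumed
# to contain a model previously exported by AutoRound.
from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)

text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
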
auto_round/__init__.py (+10, -1)

@@ -15,5 +15,14 @@
 from .mllm import AutoRoundMLLM
 from auto_round.utils import LazyImport

-from auto_round.inference.auto_quantizer import AutoHfQuantizer,AutoRoundConfig
+def __getattr__(name):
+    if name == 'AutoHfQuantizer':
+        from auto_round.inference.auto_quantizer import AutoHfQuantizer
+        return AutoHfQuantizer
+    if name == 'AutoRoundConfig':
+        from auto_round.inference.auto_quantizer import AutoRoundConfig
+        return AutoRoundConfig
+
+    raise AttributeError(f"auto-round has no attribute '{name}'")
+
 from .version import __version__

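Replacing the eager import with a module-level `__getattr__` follows PEP 562: `import auto_round` no longer drags in the inference stack, which is only imported the first time `AutoHfQuantizer` or `AutoRoundConfig` is accessed. A self-contained sketch of the same pattern (package and attribute names are illustrative, not part of auto-round):

```python
# mypkg/__init__.py -- hypothetical package demonstrating PEP 562 lazy imports
def __getattr__(name):
    if name == "HeavyHelper":
        # The costly submodule is imported only when the attribute is first used.
        from mypkg._heavy import HeavyHelper
        return HeavyHelper
    raise AttributeError(f"module 'mypkg' has no attribute '{name}'")
```
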
auto_round/autoround.py (+11, -4)

@@ -475,8 +475,8 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au
         for index in range(len(formats)):
             format = formats[index]
             if "auto_round" in format:
-                if self.sym and ("gptq" not in format and "awq" not in format):
-                    format = format.replace('auto_round', 'auto_round:gptq')
+                if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3:
+                    format = format.replace('auto_round', 'auto_round:auto_gptq')
                 formats[index] = format

         # Remove duplicates from formats list
@@ -496,6 +496,13 @@ def remove_duplicates(lst):
         # Save the quantized model in the specified formats
         folders = []
         for format in formats:
+            if "gptq" in format and not self.sym:
+                logger.warning(
+                    "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop,"
+                    " particularly for 2-bit quantization and smaller models."
+                    " We recommend exporting to either the AutoAWQ format (only 4 bits) or "
+                    "the AutoRound format (2/4/8 bits)."
+                )
             save_format_ = format.replace(":", "-").replace("_", "-")
             save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir
             self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs)
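The first hunk above reroutes plain `auto_round` exports to the `auto_round:auto_gptq` kernel whenever the quantization is symmetric or uses 3 bits, and the second adds a warning for asymmetric GPTQ exports. A standalone sketch of the remapping rule (the helper name is illustrative, not part of the library):

```python
def resolve_export_format(fmt: str, sym: bool, bits: int) -> str:
    """Mirror of the format remapping applied in quantize_and_save above."""
    if "auto_round" in fmt:
        plain = "gptq" not in fmt and "awq" not in fmt
        if (sym and plain) or bits == 3:
            fmt = fmt.replace("auto_round", "auto_round:auto_gptq")
    return fmt

assert resolve_export_format("auto_round", sym=True, bits=4) == "auto_round:auto_gptq"
assert resolve_export_format("auto_round", sym=False, bits=3) == "auto_round:auto_gptq"
assert resolve_export_format("auto_round", sym=False, bits=4) == "auto_round"
```
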
@@ -1598,8 +1605,8 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
             logger.warning(
                 "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop,"
                 " particularly for 2-bit quantization and smaller models."
-                " We recommend exporting to either the AutoAWQ format (4 bits) or "
-                "the AutoRound format (2 bits) to enhance performance."
+                " We recommend exporting to either the AutoAWQ format (only 4 bits) or "
+                "the AutoRound format (2/4/8 bits)."
             )
         if "awq" in format and not self.bits == 4:
             raise ValueError("The AWQ format only supports W4 quantization")

auto_round/export/export_to_autoround/export.py (+12, -7)

@@ -23,7 +23,7 @@

 import auto_round.export.export_to_autoround.qlinear_triton_act
 import auto_round_extension.cuda.qlinear_tritonv2
-from auto_round.utils import get_layer_names_in_block, get_module, logger, set_module, supported_layer_types
+from auto_round.utils import get_module, logger, set_module, supported_layer_types
 import threadpoolctl as tctl
 import inspect
 from tqdm import tqdm
@@ -75,15 +75,15 @@ def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_

         from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear
         return QuantLinear
-    elif "auto_round" in backend and "gptq" in backend:
-        from auto_round.export.export_to_autoround.qlinear_triton import QuantLinear  ##no g_idx
-        return QuantLinear
+    elif "auto_round" in backend and "gptq" in backend and bits in (2, 4, 8):
+        from auto_round.export.export_to_autoround.qlinear_triton import QuantLinear  ##no g_idx
+        return QuantLinear
     elif "awq" in backend:
         from ..export_to_awq.utils import WQLinear_GEMM
         return WQLinear_GEMM
     elif "gptqmodel" in backend:
         return auto_round_extension.cuda.qlinear_tritonv2.QuantLinear
-    elif "gptq" in backend and not "gptqmodel" in backend: ## have g_idx
+    elif "gptq" in backend and not "gptqmodel" in backend:  ## have g_idx
         return get_autogptq_packing_qlinear(backend, bits, group_size, sym)
     else:
         assert False, f"only support auto_gptq, auto_awq and auto_round backend"
@@ -190,7 +190,9 @@ def pack_layer(layer_name, model, backend):
     new_layer.device = device
     set_module(model, layer_name, new_layer)
     qlayer = new_layer
-    if sym:
+    import auto_round.export.export_to_autoround.qlinear_triton
+    if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear,
+                                        auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)):
         zp = int(zp.flatten()[0])

     qlayer.to("cpu")
@@ -248,7 +250,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

     ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
     if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend):
-        backend = backend.replace('auto_round', 'auto_round:gptq')
+        backend = backend.replace('auto_round', 'auto_round:auto_gptq')

     model = kwargs["model"]
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
@@ -260,6 +262,9 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
     quantization_config["quant_method"] = "intel/auto-round"

     quantization_config["backend"] = backend
+    if quantization_config["bits"] == 3:
+        backend = "auto_round:auto_gptq"
+
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
     extra_config = {}
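On the export side the same preference shows up in the packing backend recorded in `quantization_config["backend"]`, with 3-bit weights additionally forced onto the `auto_round:auto_gptq` packing path. An illustrative sketch of that decision (not the library's exact code):

```python
def resolve_packing_backend(backend: str, sym: bool, bits: int) -> str:
    """Condensed view of the backend selection in save_quantized_as_autoround above."""
    if sym and "gptq" not in backend and "awq" not in backend:
        backend = backend.replace("auto_round", "auto_round:auto_gptq")
    if bits == 3:
        backend = "auto_round:auto_gptq"
    return backend

assert resolve_packing_backend("auto_round", sym=True, bits=4) == "auto_round:auto_gptq"
assert resolve_packing_backend("auto_round", sym=False, bits=3) == "auto_round:auto_gptq"
```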

auto_round/inference/__init__.py (+1)

@@ -11,4 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from auto_round.inference.convert_model import convert_hf_model, infer_target_device, post_init

auto_round/inference/auto_quantizer.py (+26, -23)

@@ -155,26 +155,11 @@ def merge_quantization_configs(
         loading_attr_dict = quantization_config_from_args.get_loading_attributes() \
             if quantization_config_from_args is not None else None
         if isinstance(quantization_config, dict):
-            if "auto-round" in quantization_config["quant_method"]:
+            if "auto-round" in quantization_config[
+                "quant_method"] or quantization_config_from_args.__class__.__name__ == "AutoRoundConfig":
                 quantization_config = AutoRoundConfig.from_dict(quantization_config)
             else:
-                if isinstance(quantization_config_from_args, (AutoRoundConfig)):
-                    logger.info(f"Loading quantized model in auto_round format.")
-                    tmp_backend = quantization_config["quant_method"]
-                    if "auto-round" not in tmp_backend and "gptq" not in tmp_backend and "awq" not in tmp_backend:
-                        logger.error("could not convert to auto_round format, currently only supports `gptq`,`awq` or "
-                                     "`auto-round` format")
-                        exit(-1)
-                    target_backend = quantization_config["backend"] if "backend" in quantization_config else "auto"
-                    if loading_attr_dict is not None and "backend" in loading_attr_dict:
-                        target_backend = loading_attr_dict["backend"]
-                        loading_attr_dict.pop("backend")
-                    if "auto_round" not in target_backend:
-                        target_backend = f"auto_round:{tmp_backend}"
-                    quantization_config = AutoRoundConfig.from_dict(quantization_config)
-                    setattr(quantization_config, "backend", target_backend)
-                else:
-                    quantization_config = AutoQuantizationConfig.from_dict(quantization_config)  # pylint: disable=E1101
+                quantization_config = AutoQuantizationConfig.from_dict(quantization_config)  # pylint: disable=E1101

         if isinstance(quantization_config,
                       (GPTQConfig, AwqConfig, AutoRoundConfig)) and quantization_config_from_args is not None:
@@ -265,8 +250,8 @@ def __init__(

     def post_init(self):
         r"""Safety checker that arguments are correct."""
-        if self.bits not in [2, 4, 8]:
-            raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}")
+        if self.bits not in [2, 3, 4, 8]:
+            raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
         if self.group_size != -1 and self.group_size <= 0:
             raise ValueError("group_size must be greater than 0 or equal to -1")

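With the relaxed check, a 3-bit configuration now passes validation while unsupported widths still raise. A hedged sketch, assuming the usual `bits`/`group_size`/`sym` keyword arguments:

```python
from auto_round import AutoRoundConfig

AutoRoundConfig(bits=3, group_size=128, sym=True)  # accepted after this change

try:
    AutoRoundConfig(bits=5)
except ValueError as err:
    print(err)  # Only support quantization to [2,3,4,8] bits but found 5
```
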
@@ -278,6 +263,26 @@ def to_dict(self):
         config_dict = super().to_dict()
         return config_dict

+    @classmethod
+    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+        quant_method = config_dict["quant_method"]
+        if "auto-round" not in quant_method and "gptq" not in quant_method and "awq" not in quant_method:
+            raise NotImplementedError(
+                "Failed to convert to auto_round format. Only `gptqv1`, `awq`, and `auto-round` formats are supported."
+            )
+
+        if "gptq" in quant_method and "meta" in config_dict:
+            raise NotImplementedError(
+                "Failed to convert gptq format to auto_round format. Only supports `gptqv1`")
+
+        if "awq" in quant_method and config_dict.get("version", "gemm") != "gemm":
+            raise NotImplementedError(
+                "Failed to convert awq format to auto_round format. Only supports awq format with gemm version")
+
+        if "auto-round" not in quant_method:
+            config_dict["backend"] = f"auto_round:{quant_method}"
+        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
+

 class AutoRoundQuantizer(HfQuantizer):
     """Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""
"""Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""
@@ -306,7 +311,6 @@ def validate_environment(self, *args, **kwargs):
306311
"auto-round` or install from source")
307312

308313
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
309-
self.target_device = infer_target_device(self.device_map)
310314
if torch_dtype is None:
311315
torch_dtype = torch.float16
312316
elif torch_dtype != torch.float16:
@@ -330,8 +334,7 @@ class StoreAttr(object):
330334

331335
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
332336
if self.pre_quantized:
333-
target_device = self.target_device if hasattr(self, self.target_device) else infer_target_device(
334-
self.device_map)
337+
target_device = infer_target_device(self.device_map)
335338
model, used_backends = convert_hf_model(model, target_device)
336339
self.used_backends = used_backends
337340

auto_round/inference/backend.py (+68, -5)

@@ -126,7 +126,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
                                             bits=[2, 4, 8],
                                             priority=1, feature_checks=[feature_multiply_checker_32],
                                             alias=["auto_round", "tritonv2"],
-                                            requirements=["auto-round>=0.2"]
+                                            requirements=["auto-round>=0.5.0"]
                                             )

 BackendInfos['auto_round:tritonv2_zp'] = BackendInfo(device=["cuda"], sym=[True],  ## asym has accuracy issue
@@ -135,7 +135,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
                                                bits=[2, 4, 8],
                                                priority=1, feature_checks=[feature_multiply_checker_32],
                                                alias=["tritonv2", "tritonv2_zp"],
-                                               requirements=["auto-round>=0.5"]
+                                               requirements=["auto-round>=0.5.0"]
                                                )

 BackendInfos['gptqmodel:marlin'] = BackendInfo(device=["cuda"], sym=[True],
@@ -145,7 +145,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
                                               dtype=["float16", "bfloat16"],
                                               priority=6, feature_checks=[in_output_feature_multiply_checker_32],
                                               alias=["marlin", "gptqmodel"],
-                                              requirements=["gptqmodel>=2.0"]
+                                              requirements=["gptqmodel>=2.0"],
                                               )

 BackendInfos['gptqmodel:marlin_zp'] = BackendInfo(device=["cuda"], sym=[True],
@@ -504,7 +504,7 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False):
     return QuantLinear


-def find_backend(target_backend: str, orig_backend: str = None) -> str | None:
+def find_backend(target_backend: str, orig_backend: str = None):
     """
     Finds the matching backend key based on the target backend name or its aliases.

@@ -620,7 +620,10 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f
             try:
                 require_version(requirement)
             except ImportError:
-                logger.error(f"pip install '{requirement}' ")
+                if "gptqmodel" in requirement:
+                    logger.error(f"pip install -v '{requirement}' --no-build-isolation")
+                else:
+                    logger.error(f"pip install '{requirement}' ")
             else:
                 str_info = requirement()[1]
                 logger.error(str_info)
@@ -633,3 +636,63 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f
                                 reverse=True)

     return supported_backends[0]
+
+
+def get_highest_priority_backend(bits, sym, group_size, device, packing_format):
+    supported_backends = []
+    for key in BackendInfos.keys():
+        backend = BackendInfos[key]
+        # Check if device is supported by the backend
+        if device not in backend.device:
+            continue
+
+        # Check if bit-width is supported
+        if bits not in backend.bits:
+            continue
+
+        # Check if group_size is valid (if required by backend)
+        if backend.group_size is not None and group_size not in backend.group_size:
+            continue
+
+        # Check if symmetric/asymmetric quantization is supported
+        if sym not in backend.sym:
+            continue
+
+        # Check if the format is convertible when packing formats differ
+        if packing_format == backend.packing_format or packing_format in backend.convertable_format:
+            pass
+        else:
+            continue
+        supported_backends.append(key)
+
+    if len(supported_backends) > 0:
+
+        supported_backends = sorted(supported_backends,
+                                    key=lambda support_backend: BackendInfos[support_backend].priority,
+                                    reverse=True)
+        return supported_backends[0]
+    else:
+        return None
+
+
+def process_requirement(requirements: list):
+    gptqmodel_requirements = None
+    other_requirements = []
+    for requirement in requirements:
+        if "gptqmodel" in requirement:
+            gptqmodel_requirements = requirement
+        else:
+            other_requirements.append(requirement)
+
+    infos = []
+
+    if gptqmodel_requirements is not None:
+        infos.append(f"pip install -v '{gptqmodel_requirements}' --no-build-isolation")
+        infos.append(f"pip install 'numpy<2.0'")
+
+    other_info = f"pip install"
+    if len(other_requirements) > 0:
+        for requirement in other_requirements:
+            other_info += f" {requirement}"
+        infos.append(other_info)
+    return infos
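The two new helpers close the loop on backend selection: `get_highest_priority_backend` filters `BackendInfos` by device, bit-width, group size, symmetry and packing format and returns the highest-priority match (or `None`), while `process_requirement` turns requirement strings into install hints, routing `gptqmodel` through `--no-build-isolation`. A hedged usage sketch (argument values are illustrative):

```python
from auto_round.inference.backend import get_highest_priority_backend, process_requirement

# Returns a BackendInfos key such as "auto_round:tritonv2", or None if nothing matches.
best = get_highest_priority_backend(bits=4, sym=True, group_size=128,
                                    device="cuda", packing_format="auto_round")

for hint in process_requirement(["gptqmodel>=2.0", "auto-round>=0.5.0"]):
    print(hint)
# pip install -v 'gptqmodel>=2.0' --no-build-isolation
# pip install 'numpy<2.0'
# pip install auto-round>=0.5.0
```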
