From 839ba16885f68b5466545ea478f61290da3c88d9 Mon Sep 17 00:00:00 2001 From: shiqi Date: Mon, 21 Apr 2025 21:52:04 -0400 Subject: [PATCH 1/5] update_v2 --- auto_round/data_type/int.py | 14 +++++++------- auto_round/wrapper.py | 13 ++++++++++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 899abf17..b2843feb 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -64,17 +64,17 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal ## the values should be positive -def double_quant_tensor(tensor, bits, q_scale_thresh): +def double_quant_tensor(tensor, bits, q_scale_thresh, coeef): maxq = 2 ** bits - 1 wmax = torch.clamp(tensor.max(-1)[0], min=0) - scale = torch.clamp(wmax / maxq, q_scale_thresh) + scale = torch.clamp(wmax / maxq, q_scale_thresh) * coeef scale = scale.view(-1, 1) qdq_tensor = torch.clamp(round_ste(tensor / scale), max=maxq) * scale return qdq_tensor, scale @register_dtype("int_asym_dq") -def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, +def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, k_min=1.0, k_max=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, super_group_size=8, super_bits=6, **kwargs): """Quantize and de-quantize tensor asymmetrically. @@ -104,8 +104,8 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ wmin_tmp = tensor_min wmax_tmp = tensor_max if isinstance(min_scale, torch.Tensor): - wmin = wmin_tmp * min_scale - wmax = wmax_tmp * max_scale + wmin = wmin_tmp * min_scale #* k_min + wmax = wmax_tmp * max_scale #* k_max else: wmin = wmin_tmp wmax = wmax_tmp @@ -116,8 +116,8 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ wmin_m = wmin_m.view(-1, super_group_size) ##conduct double quant - scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh) - wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh) + scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh, k_max) + wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh, k_min) scale = scale.view(-1, 1) scale = torch.clamp(scale, q_scale_thresh) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index a624d9f2..8cbb5faa 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -115,6 +115,8 @@ def _init_tuning_params_and_quant_func(self): shape = get_scale_shape(orig_weight, orig_layer.group_size) self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) self._init_params("max_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) + self._init_params("k_min", p_dtype, shape, 1.0, self.enable_minmax_tuning) + self._init_params("k_max", p_dtype, shape, 1.0, self.enable_minmax_tuning) self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym) @@ -150,7 +152,7 @@ def _init_params(self, name, dtype, shape, value, tunable): setattr(self, name, p) - def _qdq_weight(self, value, min_scale, max_scale): + def _qdq_weight(self, value, min_scale, max_scale, k_min, k_max): """Quantizes and dequantizes weights with tuning parameters. 
Args: @@ -163,6 +165,8 @@ def _qdq_weight(self, value, min_scale, max_scale): """ min_scale.data.clamp_(0, 1.0) max_scale.data.clamp_(0, 1.0) + k_min.data.clamp_(0, 1.0) + k_max.data.clamp_(0, 1.0) weight = self.orig_layer.weight if weight.device.type == 'meta': weight = self.orig_layer.get_weight().to(self.device) @@ -181,6 +185,8 @@ def _qdq_weight(self, value, min_scale, max_scale): v=value, min_scale=min_scale, max_scale=max_scale, + k_min=k_min, + k_max=k_max, scale_dtype=self.orig_layer.scale_dtype, tensor_min=self.weight_min, tensor_max=self.weight_max, @@ -239,11 +245,12 @@ def unwrapper(self, best_params): v = best_params.get('value', torch.tensor(0.0)).to(self.device) min_scale = best_params.get('min_scale', torch.tensor(1.0)).to(self.device) max_scale = best_params.get('max_scale', torch.tensor(1.0)).to(self.device) - + k_min = best_params.get('k_min', torch.tensor(1.0)).to(self.device) + k_max = best_params.get('k_max', torch.tensor(1.0)).to(self.device) if self.orig_layer.weight.device.type == 'meta': self.orig_layer.to(self.device) ##unwrapper weight - qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) + qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale, k_min, k_max) self.orig_layer.weight.data.copy_(qdq_weight) self.orig_layer.weight.grad = None From 32e257d4a047c51bb11c0916e90edbf4a3b6c56e Mon Sep 17 00:00:00 2001 From: shiqi Date: Tue, 22 Apr 2025 03:37:27 -0400 Subject: [PATCH 2/5] update --- auto_round/autoround.py | 2 ++ auto_round/data_type/int.py | 11 +++++------ auto_round/wrapper.py | 24 +++++++++++++----------- test_flow.sh | 12 ++++++++++++ test_gguf.sh | 11 +++++++++++ 5 files changed, 43 insertions(+), 17 deletions(-) create mode 100644 test_flow.sh create mode 100644 test_gguf.sh diff --git a/auto_round/autoround.py b/auto_round/autoround.py index f63827cf..de3a3f3e 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -1351,11 +1351,13 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch for n, m in block.named_modules(): if hasattr(m, "orig_layer"): for key in m.params.keys(): + # breakpoint() if "min" in key or "max" in key: minmax_params.append(m.params[key]) else: round_params.append(m.params[key]) + if self.enable_minmax_tuning: optimizer = self.optimizer( [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index b2843feb..963e101f 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -74,7 +74,7 @@ def double_quant_tensor(tensor, bits, q_scale_thresh, coeef): @register_dtype("int_asym_dq") -def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, k_min=1.0, k_max=1.0, scale_dtype=torch.float16, +def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, k_wm=1.0, k_scale=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, super_group_size=8, super_bits=6, **kwargs): """Quantize and de-quantize tensor asymmetrically. 
@@ -104,8 +104,8 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ wmin_tmp = tensor_min wmax_tmp = tensor_max if isinstance(min_scale, torch.Tensor): - wmin = wmin_tmp * min_scale #* k_min - wmax = wmax_tmp * max_scale #* k_max + wmin = wmin_tmp * min_scale #* k_wm + wmax = wmax_tmp * max_scale #* k_scale else: wmin = wmin_tmp wmax = wmax_tmp @@ -114,10 +114,9 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ scale = scale.view(-1, super_group_size) wmin_m = -wmin # pylint: disable=E1130 wmin_m = wmin_m.view(-1, super_group_size) - ##conduct double quant - scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh, k_max) - wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh, k_min) + scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh, k_scale) + wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh, k_wm) scale = scale.view(-1, 1) scale = torch.clamp(scale, q_scale_thresh) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 8cbb5faa..27d83879 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -115,8 +115,10 @@ def _init_tuning_params_and_quant_func(self): shape = get_scale_shape(orig_weight, orig_layer.group_size) self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) self._init_params("max_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) - self._init_params("k_min", p_dtype, shape, 1.0, self.enable_minmax_tuning) - self._init_params("k_max", p_dtype, shape, 1.0, self.enable_minmax_tuning) + tmp_a = (self.weight_min.view(-1,self.orig_layer.super_group_size)) + tmp_b = (self.weight_max.view(-1,self.orig_layer.super_group_size)) + self._init_params("k_wm", p_dtype, get_scale_shape(tmp_a, orig_layer.super_group_size), 1.0, self.enable_minmax_tuning) + self._init_params("k_scale", p_dtype, get_scale_shape(tmp_b, orig_layer.super_group_size), 1.0, self.enable_minmax_tuning) self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym) @@ -152,7 +154,7 @@ def _init_params(self, name, dtype, shape, value, tunable): setattr(self, name, p) - def _qdq_weight(self, value, min_scale, max_scale, k_min, k_max): + def _qdq_weight(self, value, min_scale, max_scale, k_wm, k_scale): """Quantizes and dequantizes weights with tuning parameters. 
Args: @@ -165,8 +167,8 @@ def _qdq_weight(self, value, min_scale, max_scale, k_min, k_max): """ min_scale.data.clamp_(0, 1.0) max_scale.data.clamp_(0, 1.0) - k_min.data.clamp_(0, 1.0) - k_max.data.clamp_(0, 1.0) + k_wm.data.clamp_(0, 1.0) + k_scale.data.clamp_(0, 1.0) weight = self.orig_layer.weight if weight.device.type == 'meta': weight = self.orig_layer.get_weight().to(self.device) @@ -185,8 +187,8 @@ def _qdq_weight(self, value, min_scale, max_scale, k_min, k_max): v=value, min_scale=min_scale, max_scale=max_scale, - k_min=k_min, - k_max=k_max, + k_wm=k_wm, + k_scale=k_scale, scale_dtype=self.orig_layer.scale_dtype, tensor_min=self.weight_min, tensor_max=self.weight_max, @@ -245,12 +247,12 @@ def unwrapper(self, best_params): v = best_params.get('value', torch.tensor(0.0)).to(self.device) min_scale = best_params.get('min_scale', torch.tensor(1.0)).to(self.device) max_scale = best_params.get('max_scale', torch.tensor(1.0)).to(self.device) - k_min = best_params.get('k_min', torch.tensor(1.0)).to(self.device) - k_max = best_params.get('k_max', torch.tensor(1.0)).to(self.device) + k_wm = best_params.get('k_wm', torch.tensor(1.0)).to(self.device) + k_scale = best_params.get('k_scale', torch.tensor(1.0)).to(self.device) if self.orig_layer.weight.device.type == 'meta': self.orig_layer.to(self.device) ##unwrapper weight - qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale, k_min, k_max) + qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale, k_wm, k_scale) self.orig_layer.weight.data.copy_(qdq_weight) self.orig_layer.weight.grad = None @@ -357,7 +359,7 @@ def forward(self, x): torch.Tensor: Output tensor after applying the wrapped layer. """ x = x.to(self.device) - weight_q, _, _ = self._qdq_weight(self.value, self.min_scale, self.max_scale) + weight_q, _, _ = self._qdq_weight(self.value, self.min_scale, self.max_scale, self.k_wm, self.k_scale) if self.enable_act_quant: act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None diff --git a/test_flow.sh b/test_flow.sh new file mode 100644 index 00000000..b2d13b3f --- /dev/null +++ b/test_flow.sh @@ -0,0 +1,12 @@ +for model_name in "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do +device=1 +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format gguf:q4_k_s \ + --iters 200 \ + --model ${model_name} \ + --output_dir /data5/shiqi/${model_name}_tune \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + --eval_task_by_task \ + 2>&1 | tee /data5/shiqi/log/gguf_test/${model_name}_tune.log +done \ No newline at end of file diff --git a/test_gguf.sh b/test_gguf.sh new file mode 100644 index 00000000..24d41cb4 --- /dev/null +++ b/test_gguf.sh @@ -0,0 +1,11 @@ +model="/models/Qwen2.5-7B-Instruct/" +device=1 +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format gguf:q4_k_s \ + --iters 200 \ + --model $model \ + --output /data5/shiqi/ \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + --eval_task_by_task \ + 2>&1 | tee /data5/shiqi/log/gguf_test/gguf_double_Qwen2.5-7B-Instruct.log \ No newline at end of file From def32cccf86b91af42e27e023fa0fc67fc8efff7 Mon Sep 17 00:00:00 2001 From: shiqi Date: Tue, 22 Apr 2025 21:02:32 -0400 Subject: [PATCH 3/5] sh --- test_flow.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_flow.sh b/test_flow.sh index b2d13b3f..4ddce3b5 100644 --- a/test_flow.sh +++ 
b/test_flow.sh @@ -3,7 +3,7 @@ device=1 CUDA_VISIBLE_DEVICES=$device python -m auto_round \ --format gguf:q4_k_s \ --iters 200 \ - --model ${model_name} \ + --model /models/${model_name} \ --output_dir /data5/shiqi/${model_name}_tune \ --eval_bs 16 \ --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ From 70a1cd9278fb32c49b49434026941ea81cc7472f Mon Sep 17 00:00:00 2001 From: shiqi Date: Wed, 23 Apr 2025 21:42:32 -0400 Subject: [PATCH 4/5] update sh --- test_flow.sh | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test_flow.sh b/test_flow.sh index 4ddce3b5..870de327 100644 --- a/test_flow.sh +++ b/test_flow.sh @@ -1,12 +1,17 @@ -for model_name in "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do -device=1 +for model_name in "Qwen2.5-7B-Instruct" "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do +device=6 +format=fake CUDA_VISIBLE_DEVICES=$device python -m auto_round \ - --format gguf:q4_k_s \ + --format ${format} \ + --data_type int_asym_dq \ + --group_size 32 \ + --super_bits 6 \ + --super_group_size 8 \ + --bits 4 \ --iters 200 \ --model /models/${model_name} \ - --output_dir /data5/shiqi/${model_name}_tune \ + --output_dir /data5/shiqi/${format}_${model_name}_tune \ --eval_bs 16 \ --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ - --eval_task_by_task \ - 2>&1 | tee /data5/shiqi/log/gguf_test/${model_name}_tune.log + 2>&1 | tee /data5/shiqi/log/gguf_test/${format}_${model_name}_tune.log done \ No newline at end of file From 975954b156dc010c090d6d132da7d95a5533e758 Mon Sep 17 00:00:00 2001 From: shiqi Date: Fri, 25 Apr 2025 05:02:12 -0400 Subject: [PATCH 5/5] update --- auto_round/data_type/int.py | 6 +- auto_round/script/llm.py | 329 ++++++++++++++++++++---------------- q2_test_tune.sh | 19 +++ 3 files changed, 201 insertions(+), 153 deletions(-) create mode 100644 q2_test_tune.sh diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 963e101f..33b846ec 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -122,11 +122,11 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ scale = torch.clamp(scale, q_scale_thresh) wmin_m = wmin_m.view(-1, 1) - int_w = round_ste(tensor / scale + v) - q = torch.clamp(int_w + round_ste(wmin_m / scale), 0, maxq) + int_w = round_ste((tensor + wmin_m) / scale + v) + q = torch.clamp(int_w, 0, maxq) qdq_result = (scale * q - wmin_m).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - zp = round_ste(wmin_m / scale) # remove this later + #zp = round_ste(wmin_m / scale) # remove this later return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py index b3685d9f..d4c16953 100644 --- a/auto_round/script/llm.py +++ b/auto_round/script/llm.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- + # Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,23 +35,23 @@ clear_memory, get_device_and_parallelism, set_cuda_visible_devices) - + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - - + + class BasicArgumentParser(argparse.ArgumentParser): - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument( "--model", "--model_name", "--model_name_or_path", default="facebook/opt-125m", help="model name or path") - + self.add_argument('--eval', action='store_true', help="whether to use eval only mode") - + self.add_argument("--bits", default=4, type=int, help="number of weight bits") - + self.add_argument("--eval_bs", default=None, type=int, help="batch size in evaluation") - + self.add_argument( "--device", "--devices", @@ -62,38 +62,38 @@ def __init__(self, *args, **kwargs): "The default is set to cuda:0," "allowing for automatic detection and switch to HPU or CPU." "set --device 0,1,2 to use multiple cards.") - + self.add_argument("--asym", action='store_true', help="whether to use asym quantization") - + self.add_argument( "--dataset", default="NeelNanda/pile-10k", type=str, help="the dataset for quantization training") - + self.add_argument( "--minmax_lr", default=None, type=float, help="minmax learning rate, if None, it will beset to be the same with lr") - + self.add_argument("--seed", default=42, type=int, help="random seed") - + self.add_argument("--adam", action='store_true', help="whether to use adam optimizer instead of SignSGD") - + self.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps") - + self.add_argument("--nblocks", default=1, type=int, help="how many blocks to tune together") - + self.add_argument("--low_gpu_mem_usage", action='store_true', help="offload intermediate features to cpu") - + self.add_argument("--format", default="auto_round", type=str, help="the format to save the model") - + self.add_argument("--data_type", "--dtype", default='int', help="data type for tuning, 'int', 'mx_fp' and etc") - + self.add_argument( "--scale_dtype", default='fp16', choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="scale data type to use for quantization") - + self.add_argument( "--tasks", "--task", @@ -102,35 +102,35 @@ def __init__(self, *args, **kwargs): "openbookqa,boolq,arc_easy,arc_challenge", default=None, help="lm-eval tasks") - + self.add_argument( "--output_dir", default="./tmp_autoround", type=str, help="the directory to save quantized model") - + self.add_argument("--disable_eval", action='store_true', help="whether to disable lm-eval evaluation after tuning") - + self.add_argument( "--eval_task_by_task", action="store_true", help="whether to eval task by task.") - + self.add_argument("--disable_amp", action='store_true', help="disable amp") - + self.add_argument( "--disable_minmax_tuning", action='store_true', help="whether to disable enable weight minmax tuning") - + self.add_argument("--enable_norm_bias_tuning", action='store_true', help="whether to enable norm bias tuning") - + self.add_argument( "--disable_trust_remote_code", action='store_true', help="whether to disable trust_remote_code") - + self.add_argument( "--disable_quanted_input", action='store_true', help="whether to disuse the output of quantized block to tune the next block") - + self.add_argument("--quant_lm_head", action='store_true', help="whether to quant lm_head") - + self.add_argument( "--low_cpu_mem_mode", default=0, @@ -143,58 
+143,58 @@ def __init__(self, *args, **kwargs): "2 means choose layer-wise mode, load the weights of each layer from disk when tuning," " minimum memory consumption and also slowest running speed." "others means not use low cpu memory. Default to 0, not use low cpu memory.") - + self.add_argument( "--low_cpu_mem_tmp_dir", default=None, type=str, help="temporary work space to store the temporary files " "when using low cpu memory mode. Will remove after tuning.") - + self.add_argument( "--model_dtype", default=None, type=str, choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="force to convert the dtype, some backends supports fp16 dtype better") - + self.add_argument("--act_bits", default=16, type=int, help="activation bits") - + self.add_argument( "--fp_layers", default="", type=str, help="list of Layer names to maintain original data type") - + self.add_argument( "--not_use_best_mse", action='store_true', help="whether to use the iter of best mes loss in the tuning phase") - + self.add_argument( "--to_quant_block_names", default=None, type=str, help="Names of quantitative blocks, please use commas to separate them.") - + self.add_argument("--enable_torch_compile", action='store_true', help="whether to enable torch compile") - + self.add_argument("--act_data_type", "--act_dtype", default=None, type=str, help="activation data type") - + self.add_argument("--disable_act_dynamic", action='store_true', help="activation static quantization") - + self.add_argument("--disable_deterministic_algorithms", action='store_true', help="disable torch deterministic algorithms.") - + self.add_argument("--device_map", default=None, type=str, help="device_map for block in tuning phase") - + self.add_argument( "--super_group_size", default=None, type=int, help="the number of super group size when use double quant.") - + self.add_argument( "--super_bits", default=None, type=int, help="number of scale and mins quant bits for double quant.") - - + + class EvalArgumentParser(argparse.ArgumentParser): - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument( @@ -209,7 +209,7 @@ def __init__(self, *args, **kwargs): "The default is set to cuda:0," "allowing for automatic detection and switch to HPU or CPU." 
"set --device 0,1,2 to use multiple cards.") - + self.add_argument("--tasks", "--task", default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," \ "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge", @@ -218,148 +218,148 @@ def __init__(self, *args, **kwargs): "--disable_trust_remote_code", action='store_true', help="whether to disable trust_remote_code") self.add_argument("--eval_bs", "--bs", "--batch_size", default=None, type=int, help="batch size in evaluation") self.add_argument("--eval_task_by_task", action='store_true', help="whether to eval task by task.") - - + + def setup_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=200, type=int, help="iteration to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() return args - - + + def setup_best_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=1000, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=512, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() args.low_gpu_mem_usage = True - + return args - - + + def setup_light_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=50, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=5e-3, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() - + return args - - + + def setup_fast_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=4, type=int, help="train batch size") - + parser.add_argument("--iters", default=200, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=512, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, 
help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() - + return args - - + + def setup_eval_parser(): parser = EvalArgumentParser() args = parser.parse_args() return args - - + + def tune(args): if args.disable_eval: logging.warning("`disable_eval` is deprecated and is now set by default.") - + if args.eval_bs is None: args.eval_bs = "auto" - + import transformers - + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig - + from auto_round.utils import detect_device, get_library_version from auto_round.utils import logger, _gguf_args_check - + if args.format is None: args.format = "auto_round" - + formats = args.format.lower().replace(' ', '').split(",") from auto_round.utils import supported_formats for format in formats: if format not in supported_formats: raise ValueError(f"{format} is not supported, we only support {supported_formats}") - + args = _gguf_args_check(args) - + if "auto_gptq" in args.format and args.asym is True: logger.warning("the auto_gptq kernel has issues with asymmetric quantization. " "It is recommended to use sym quantization or --format='auto_round'") - + if "marlin" in args.format and args.asym is True: assert False, "marlin backend only supports sym quantization, please remove --asym" - + ##must set this before import torch set_cuda_visible_devices(args.device) device_str, use_auto_mapping = get_device_and_parallelism(args.device) - + import torch if not args.disable_deterministic_algorithms: torch.use_deterministic_algorithms(True, warn_only=True) # logger.info("`torch.use_deterministic_algorithms` is enabled by default for reproducibility " # "and can be disabled using the `--disable_deterministic_algorithms` argument.") - + if args.enable_torch_compile: logger.info("`torch.compile` is enabled to reduce tuning costs. 
" "If it causes issues, you can disable it by remove `--enable_torch_compile` argument.") - + model_name = args.model if model_name[-1] == "/": model_name = model_name[:-1] @@ -367,7 +367,7 @@ def tune(args): torch_dtype = "auto" if device_str is not None and "hpu" in device_str: torch_dtype = torch.bfloat16 - + from auto_round.utils import llm_load_model model, tokenizer, low_cpu_mem_usage = llm_load_model( model_name, @@ -378,25 +378,25 @@ def tune(args): low_cpu_mem_mode=args.low_cpu_mem_mode, low_cpu_mem_tmp_dir=args.low_cpu_mem_tmp_dir, model_dtype=args.model_dtype) - + from auto_round import AutoRound, AutoRoundAdam - + seqlen = args.seqlen - + if hasattr(tokenizer, "model_max_length"): if tokenizer.model_max_length < seqlen: logger.info( f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length") seqlen = min(seqlen, tokenizer.model_max_length) args.seqlen = seqlen - + if "bloom" in model_name: args.low_gpu_mem_usage = False - + round = AutoRound if args.adam: round = AutoRoundAdam - + layer_config = {} for n, m in model.named_modules(): if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D): @@ -405,7 +405,7 @@ def tune(args): logger.info( f"{n} will not be quantized due to its shape not being divisible by 32," " resulting in an exporting issue to autogptq") - + not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers) for name in not_quantize_layer_names: layer_config[name] = {"bits": 16} @@ -415,7 +415,7 @@ def tune(args): if "auto_round" not in format and "fake" not in format and "awq" not in format: ##TODO gptq could support some mixed precision config logger.warning(f"mixed precision exporting does not support {format} currently") - + lm_head_layer_name = "lm_head" for n, _ in model.named_modules(): lm_head_layer_name = n @@ -430,7 +430,7 @@ def tune(args): f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " f"supported currently") break - + if args.quant_lm_head: layer_config[lm_head_layer_name] = {"bits": args.bits} for format in formats: @@ -438,16 +438,16 @@ def tune(args): auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")] raise ValueError( f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}") - + if "auto_awq" in args.format: from auto_round.utils import check_awq_gemm_compatibility awq_supported, info = check_awq_gemm_compatibility( model, args.bits, args.group_size, not args.asym, layer_config) if not awq_supported: logger.warning(f"The AutoAWQ format may not be supported due to {info}") - + enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False - + autoround = round( model, tokenizer, @@ -484,59 +484,60 @@ def tune(args): super_group_size=args.super_group_size, super_bits=args.super_bits, ) - + model_name = args.model.rstrip("/") if model_name.split('/')[-1].strip('.') == "": export_dir = os.path.join(args.output_dir, f"w{args.bits}g{args.group_size}") else: export_dir = os.path.join(args.output_dir, model_name.split('/')[-1] + f"-w{args.bits}g{args.group_size}") - + model, folders = autoround.quantize_and_save(export_dir, format=args.format) - + if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2: import shutil shutil.rmtree(args.low_cpu_mem_tmp_dir, ignore_errors=True) - + model.eval() clear_memory() - + lm_eval_version = get_library_version("lm-eval") - + eval_folder = folders[-1] if args.tasks is None or args.tasks == "" or eval_folder is 
None: return - + tasks = args.tasks if isinstance(tasks, str): tasks = tasks.split(',') - + from lm_eval.utils import make_table # pylint: disable=E0401 - + logger.info(f"Using lm-eval version {lm_eval_version}") eval_gguf_model = False for file in os.listdir(eval_folder): if file.endswith("gguf"): eval_gguf_model = True break - + if args.act_bits <= 8 or eval_gguf_model: if eval_gguf_model: # gguf floder only contains one file for file in os.listdir(eval_folder): gguf_file = file user_model = AutoModelForCausalLM.from_pretrained( - eval_folder, gguf_file=gguf_file, device_map="auto" if use_auto_mapping else None) + eval_folder, gguf_file=gguf_file, device_map="auto") + user_model = user_model.to(torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(eval_folder, gguf_file=gguf_file) else: if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: from accelerate.big_modeling import dispatch_model - + dispatch_model(model, model.hf_device_map) user_model = model else: device_str = detect_device(device_str) user_model = model.to(device_str) - + if args.eval_task_by_task: eval_task_by_task( user_model, tokenizer=tokenizer, device=device_str, tasks=args.tasks, batch_size=args.eval_bs) @@ -558,8 +559,8 @@ def tune(args): res = simple_evaluate( model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) print(make_table(res)) - - + + def _eval_init(tasks, model_path, device, disable_trust_remote_code=False): set_cuda_visible_devices(device) device_str, parallelism = get_device_and_parallelism(device) @@ -569,35 +570,57 @@ def _eval_init(tasks, model_path, device, disable_trust_remote_code=False): if isinstance(tasks, str): tasks = tasks.split(',') return tasks, model_args, device_str - - + + def eval(args): tasks, model_args, device_str = _eval_init(args.tasks, args.model, args.device, args.disable_trust_remote_code) - + # load after _eval_int in order to make sure import torch after set CUDA_VISBILE_DEVICES - from auto_round.eval.evaluation import simple_evaluate - - res = simple_evaluate(model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) - - from lm_eval.utils import make_table # pylint: disable=E0401 - print(make_table(res)) - - + from auto_round.eval.evaluation import simple_evaluate, simple_evaluate_user_model + + is_gguf_file = False + if os.path.isfile(args.model) and args.model.endswith(".gguf"): + is_gguf_file = True + gguf_file = os.path.basename(args.model) + model = os.path.dirname(args.model) + else: + for file in os.listdir(model): + if file.endswith(".gguf"): + is_gguf_file = True + gguf_file = file + if is_gguf_file: + import torch + from transformers import AutoTokenizer, AutoModelForCausalLM + from lm_eval.utils import make_table # pylint: disable=E0401 + tokenizer = AutoTokenizer.from_pretrained(model, gguf_file=gguf_file) + user_model = AutoModelForCausalLM.from_pretrained(model, gguf_file=gguf_file, device_map="auto") + user_model = user_model.to(torch.bfloat16) + res = simple_evaluate_user_model( + user_model, tokenizer, tasks=tasks, batch_size=args.eval_bs, device=device_str) + print(make_table(res)) + else: + res = simple_evaluate( + model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) + + from lm_eval.utils import make_table # pylint: disable=E0401 + print(make_table(res)) + + def eval_task_by_task( model, device=None, tasks=None, tokenizer=None, batch_size=None, max_batch_size=64, trust_remote_code=True): set_cuda_visible_devices(device) 
device_str, parallelism = get_device_and_parallelism(device) - + # load after _eval_int in order to make sure import torch after set CUDA_VISBILE_DEVICES import traceback from auto_round.utils import logger from lm_eval import simple_evaluate as lm_simple_evaluate from lm_eval.models.huggingface import HFLM from transformers import AutoModelForCausalLM, AutoTokenizer - + from auto_round import AutoRoundConfig # pylint: disable=E0611 if batch_size is None: - batch_size = "auto" + batch_size = "auto:8" is_gguf_file = False if not isinstance(model, str): parallelism = False @@ -612,8 +635,11 @@ def eval_task_by_task( is_gguf_file = True gguf_file = file if is_gguf_file: + import torch tokenizer = AutoTokenizer.from_pretrained(model, gguf_file=gguf_file) + # float32 model = AutoModelForCausalLM.from_pretrained(model, gguf_file=gguf_file, device_map="auto") + model = model.to(torch.bfloat16) hflm = HFLM( pretrained=model, tokenizer=tokenizer, @@ -622,13 +648,15 @@ def eval_task_by_task( max_batch_size=max_batch_size, parallelize=parallelism, trust_remote_code=trust_remote_code) - + if isinstance(tasks, str): tasks = tasks.replace(" ", "").split(",") - + from lm_eval.utils import make_table # pylint: disable=E0401 res_all = {} res_keys = ["results", "versions", "n-shot", "higher_is_better"] + import time + st = time.time() for task in tasks: try: res = lm_simple_evaluate(model=hflm, model_args=None, device=device_str, tasks=task, batch_size=batch_size) @@ -646,11 +674,12 @@ def eval_task_by_task( except Exception as e: traceback.print_exc() continue - + if not res_all: res_all = res else: for key in res_keys: res_all[key].update(res[key]) print(make_table(res_all)) - + + print("total eval time:", time.time() - st) \ No newline at end of file diff --git a/q2_test_tune.sh b/q2_test_tune.sh new file mode 100644 index 00000000..c0945098 --- /dev/null +++ b/q2_test_tune.sh @@ -0,0 +1,19 @@ +for model_name in "Qwen2.5-7B-Instruct" "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do +device=5 +format=fake +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format ${format} \ + --data_type int_asym_dq \ + --group_size 16 \ + --super_bits 4 \ + --act_bits 16 \ + --super_group_size 16 \ + --bits 2 \ + --iters 200 \ + --asym \ + --model /models/${model_name} \ + --output_dir /data5/shiqi/${format}_q2_k_s_${model_name}_no_tune \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + 2>&1 | tee /data5/shiqi/log/gguf_test/${format}_q2_k_s_${model_name}_no_tune.log +done \ No newline at end of file
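
Reviewer note: the functional core of this series is the tunable double-quantization path in quant_tensor_asym_dq. Per-group scales and offsets (wmin_m) are regrouped into super-groups of super_group_size and quantized to super_bits, a learnable coefficient (k_scale / k_wm, clamped to [0, 1]) is multiplied into that second-level scale, and the last patch folds the offset into the rounding step, i.e. round((w + wmin_m) / scale) instead of rounding the weight and the offset separately. The snippet below is a minimal, self-contained sketch of that flow under assumed shapes; it is not the library code. Names such as double_quant, qdq_asym_dq and coeff are illustrative stand-ins for the patched functions and the coeef / k_scale / k_wm parameters, and gradient handling is reduced to a plain straight-through estimator.

import torch

def round_ste(x):
    # straight-through estimator: round in the forward pass, identity gradient in backward
    return (x.round() - x).detach() + x

def double_quant(t, bits, thresh, coeff=1.0):
    # quantize the (positive) per-group scales or offsets themselves to `bits`,
    # with a tunable coefficient folded into the second-level scale
    maxq = 2 ** bits - 1
    wmax = torch.clamp(t.max(dim=-1)[0], min=0)
    scale = torch.clamp(wmax / maxq, min=thresh) * coeff
    scale = scale.view(-1, 1)
    qdq = torch.clamp(round_ste(t / scale), max=maxq) * scale
    return qdq, scale

def qdq_asym_dq(w, bits=4, super_bits=6, super_group_size=8, thresh=1e-5,
                k_scale=1.0, k_wm=1.0):
    # w: (n_groups, group_size), already split into quantization groups;
    # n_groups is assumed to be a multiple of super_group_size
    maxq = 2 ** bits - 1
    wmin = w.min(dim=-1, keepdim=True)[0].clamp(max=0)
    wmax = w.max(dim=-1, keepdim=True)[0].clamp(min=0)
    scale = ((wmax - wmin) / maxq).clamp(min=thresh)
    wmin_m = -wmin  # positive offset

    # double quant: quantize first-level scales and offsets in super-groups
    scale, _ = double_quant(scale.view(-1, super_group_size), super_bits, thresh, k_scale)
    wmin_m, _ = double_quant(wmin_m.view(-1, super_group_size), super_bits, thresh, k_wm)
    scale = scale.view(-1, 1).clamp(min=thresh)
    wmin_m = wmin_m.view(-1, 1)

    # final patch: shift by the offset before rounding, then clamp once
    q = torch.clamp(round_ste((w + wmin_m) / scale), 0, maxq)
    return scale * q - wmin_m

if __name__ == "__main__":
    w = torch.randn(16, 32)  # 16 groups of size 32 -> 2 super-groups of 8
    wq = qdq_asym_dq(w)
    print("mean abs qdq error:", (w - wq).abs().mean().item())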