diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 39605364..818b9893 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -15,7 +15,8 @@ import torch from .utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad from auto_round.data_type.register import register_dtype - +import numpy as np +from concurrent.futures import ProcessPoolExecutor @register_dtype("int_sym") def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, @@ -62,7 +63,6 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp - ## the values should be positive def double_quant_tensor(tensor, bits, q_scale_thresh): maxq = 2 ** bits - 1 @@ -72,7 +72,6 @@ def double_quant_tensor(tensor, bits, q_scale_thresh): qdq_tensor = torch.clamp(round_ste(tensor / scale), max=maxq) * scale return qdq_tensor, scale - @register_dtype("int_asym_dq") def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, super_group_size=8, super_bits=6, @@ -109,20 +108,22 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ else: wmin = wmin_tmp wmax = wmax_tmp - scale = ((wmax - wmin) / maxq).to(scale_dtype) + + # scale_old = ((wmax - wmin) / maxq).to(scale_dtype) + scale,wmin_m = quant_tensor_k_quant_cuda(tensor,num_bits=bits, group_size=group_size) + scale = scale.squeeze(-1) + scale = torch.from_numpy(scale).to(tensor.dtype).cuda() + wmin_m = torch.from_numpy(wmin_m).to(tensor.dtype).cuda() scale = torch.clamp(scale, min=q_scale_thresh) + wmin_m = torch.clamp(wmin_m, min=q_scale_thresh) scale = scale.view(-1, super_group_size) - wmin_m = -wmin # pylint: disable=E1130 wmin_m = wmin_m.view(-1, super_group_size) - - ##conduct double quant + #conduct double quant scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh) wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh) - scale = scale.view(-1, 1) scale = torch.clamp(scale, q_scale_thresh) wmin_m = wmin_m.view(-1, 1) - int_w = round_ste((tensor + wmin_m) / scale + v) q = torch.clamp(int_w, 0, maxq) qdq_result = (scale * q - wmin_m).to(tensor.dtype) @@ -130,6 +131,88 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ # zp = round_ste(wmin_m / scale) # remove this later return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} +def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32): + """Quantize tensor per group based on k quant. + Ref: https://github.com/intel/neural-compressor/pull/2169/files + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 4. 
+ Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + try: + import cupy as cp + import torch + + if torch.cuda.is_available(): + data = data.to(torch.float64) + data = cp.asarray(data) + data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size) + maxq = 2**num_bits - 1 + minq = 0 + sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1) + av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1) + weights = cp.add(av_x, cp.abs(data)) # (nb, group_size) + rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1) + rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1) + sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1) + sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, group_size) + iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + mask = rmin != rmax + iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask]) + scale = 1 / iscale + quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size) + diff = scale * quant_data + rmin - data # (nb, group_size) + best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + nstep = 20 + rdelta = 0.1 + rrmin = -1 + for is_ in range(nstep): + iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0] + mask = rmin != rmax + iscale_new[mask] = factor / (rmax[mask] - rmin[mask]) + quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size) + mul_weights_quant_data_new = weights * quant_data_new + sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1) + D = cp.subtract(sum_w * sum_l2, sum_l**2) # (nb, 1) + + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1) + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1) + + diff = this_scale * quant_data_new + this_min - data # (nb, group_size) + mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + + mad_1 = cp.array(mad) + best_mad_1 = cp.array(best_mad) + # idx_to_replace = cp.where((mad_1 < best_mad_1) & (D > 0))[0] + idx_to_replace = cp.where(mad_1 < best_mad_1)[0] + quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :] + best_mad[idx_to_replace] = mad[idx_to_replace] + scale[idx_to_replace] = this_scale[idx_to_replace] + rmin[idx_to_replace] = this_min[idx_to_replace] + + scale = scale.astype(cp.float64) + rmin = rmin.astype(cp.float64) + return scale.get(),-rmin.get() + else: + logger.warning( + "Try to use k-quant quantization on CUDA. However, CUDA is not available." + "Fall back to k-quant quantization on CPU." + ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) + except ImportError: + logger.info( + "Now we are using k-quant quantization on cpu, which is time consuming." + "Please consider install cupy to speed up on CUDA. See https://cupy.dev/" + "Please also install torch to check CUDA availability." 
+ ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) @register_dtype("int_asym") def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, @@ -230,7 +313,6 @@ def quant_tensor_sym_gptq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp - def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): diff --git a/auto_round/export/export_to_gguf/quant.py b/auto_round/export/export_to_gguf/quant.py index ca75ad4f..c09d39bb 100644 --- a/auto_round/export/export_to_gguf/quant.py +++ b/auto_round/export/export_to_gguf/quant.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - + import numpy as np from tqdm import tqdm from concurrent.futures import ProcessPoolExecutor - + QK_K = 256 K_SCALE_SIZE = 12 GGML_QUANT_SIZES = { @@ -26,32 +26,32 @@ "q2_k": (256, 2 + 2 + QK_K//16 + QK_K//4), "q8_0": (32, 2 + 32) } - + GGML_QUANT_BLOCK = {} - - + + def register_block(name): - + def register(cls): GGML_QUANT_BLOCK[name] = cls return cls - + return register - - + + def ggml_quant(data: np.array, ggml_type, scale=None, zp=None, wmin_m=None, d_scale=None, d_wmin_m=None): block_size, type_size = GGML_QUANT_SIZES[ggml_type] - + data = data.astype(np.float32, copy=False) shape = data.shape n_blocks = data.size // block_size blocks = data.reshape((n_blocks, block_size)) - + if ggml_type.endswith("_k"): worker = 16 else: worker = 0 - + if worker > 0: n_groups = (data.shape[0] // worker) or 1 blocks = np.array_split(blocks, n_groups, axis=0) @@ -60,7 +60,7 @@ def ggml_quant(data: np.array, ggml_type, scale=None, zp=None, wmin_m=None, d_sc wmin_m = np.array_split(wmin_m, n_groups, axis=0) if wmin_m is not None else [None] * n_groups d_scale = np.array_split(d_scale, n_groups, axis=0) if d_scale is not None else [None] * n_groups d_wmin_m = np.array_split(d_wmin_m, n_groups, axis=0) if d_wmin_m is not None else [None] * n_groups - + quant_func = GGML_QUANT_BLOCK[ggml_type] if ggml_type.endswith("_k"): with ProcessPoolExecutor(worker) as executor: @@ -76,20 +76,20 @@ def ggml_quant(data: np.array, ggml_type, scale=None, zp=None, wmin_m=None, d_sc new_data = quant_func(blocks, scale, zp, wmin_m=wmin_m, d_scale=d_scale, d_wmin_m=d_wmin_m) else: new_data = quant_func(blocks, scale, zp) - + assert new_data.dtype == np.uint8 assert new_data.shape[-1] == type_size new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size) return new_data - - + + def np_roundf(n: np.ndarray) -> np.ndarray: a = abs(n) floored = np.floor(a) b = floored + np.floor(2 * (a-floored)) return np.sign(n) * b - - + + @register_block("bf16") def bf16_quant_block(blocks: np.array, scale=None, zp=None): n = blocks.view(np.uint32) @@ -98,8 +98,8 @@ def bf16_quant_block(blocks: np.array, scale=None, zp=None): # round to nearest even n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16 return n.astype(np.uint16).view(np.uint8) - - + + @register_block("q4_0") def q4_0_quant_block(blocks: np.array, scale=None, zp=None): if scale is not None: @@ -110,20 +110,20 @@ def q4_0_quant_block(blocks: np.array, scale=None, zp=None): d = max / -8 with np.errstate(divide="ignore"): id = 
np.where(d == 0, 0, 1 / d) - + qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15) - + n_blocks = blocks.shape[0] block_size = GGML_QUANT_SIZES["q4_0"][0] qs = qs.reshape((n_blocks, 2, block_size // 2)) qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) - + d = d.astype(np.float16).view(np.uint8) - + return np.concatenate([d, qs], axis=-1) - - + + @register_block("q4_1") def q4_1_quant_block(blocks: np.array, scale=None, zp=None): if scale is not None: @@ -135,19 +135,19 @@ def q4_1_quant_block(blocks: np.array, scale=None, zp=None): d = (max-min) / 15 with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) - + qs = np.trunc((blocks-min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15) - + n_blocks = blocks.shape[0] block_size = GGML_QUANT_SIZES["q4_1"][0] qs = qs.reshape((n_blocks, 2, block_size // 2)) qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) - + d = d.astype(np.float16).view(np.uint8) m = min.astype(np.float16).view(np.uint8) return np.concatenate([d, m, qs], axis=-1) - - + + @register_block("q8_0") def q8_0_quant_block(blocks: np.array, scale=None, zp=None) -> np.ndarray: if scale is not None: @@ -157,74 +157,138 @@ def q8_0_quant_block(blocks: np.array, scale=None, zp=None) -> np.ndarray: with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) qs = np_roundf(blocks * id) - + # (n_blocks, 2) d = d.astype(np.float16).view(np.uint8) # (n_blocks, block_size) qs = qs.astype(np.int8).view(np.uint8) - + return np.concatenate([d, qs], axis=1) - - + +def make_qkx2_quants_multi(data, weight, nmax, group_size, rmin=-1, rdelta=0.1, nstep=20, use_mad=False): + # output_values + bs, group_size = data.shape + + group_min = np.min(data, axis=-1) + group_max = np.max(data, axis=-1) + + the_mins = -group_min + + sum_w = np.sum(weight, axis=-1) + sum_x = np.sum(weight * data, axis=-1) + + group_min[group_min > 0] = 0 + + # if group_min == group_max: + # L = np.zeros(group_size, dtype=np.uint8) + # the_min = -group_min + # return 0.0, L, the_min + + scale = (group_max - group_min) / nmax + iscale = np.where(scale == 0, 0, 1 /scale) + + scale = scale.reshape(-1, 1) + iscale = iscale.reshape(-1, 1) + group_min = group_min.reshape(-1, 1) + group_max = group_max.reshape(-1, 1) + + l_values = np.round(iscale * (data - group_min)) + L = np.clip(l_values, 0, nmax).astype(np.uint8) + + diffs = scale * L + group_min - data + diffs = np.abs(diffs) if use_mad else diffs**2 + best_mad = np.sum(weight * diffs) + + if nstep < 1: + return scale, L, the_mins + + for step in range(nstep): + iscale = (rmin + rdelta * step + nmax) / (group_max - group_min) + l_values = np.round(iscale * (data - group_min)) + Laux = np.clip(l_values, 0, nmax).astype(np.uint8) + + sum_l = np.sum(weight * Laux) + sum_l2 = np.sum(weight * Laux**2) + sum_xl = np.sum(weight * Laux * data) + + D = sum_w * sum_l2 - sum_l * sum_l + if D > 0: + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D + if this_min > 0: + this_min = 0 + this_scale = sum_xl / sum_l2 + + diffs = this_scale * Laux + this_min - data + diffs = np.abs(diffs) if use_mad else diffs**2 + mad = np.sum(weight * diffs) + + if mad < best_mad: + L = Laux.copy() + best_mad = mad + scale = this_scale + group_min = this_min + + the_min = -group_min + return scale, L, the_min + def make_qkx2_quants(data, weight, nmax, group_size, rmin=-1, rdelta=0.1, nstep=20, use_mad=False): group_min = np.min(data) group_max = 
np.max(data) - sum_w = np.sum(weight) sum_x = np.sum(weight * data) - - if group_min > 0: - group_min = 0 + + group_min = min(group_min, 0) if group_min == group_max: L = np.zeros(group_size, dtype=np.uint8) the_min = -group_min return 0.0, L, the_min - - iscale = nmax / (group_max-group_min) + + iscale = nmax / (group_max - group_min) scale = 1 / iscale - L = np.zeros(group_size, dtype=np.uint8) - - l_values = np.round(iscale * (data-group_max)) + + l_values = np.round(iscale * (data-group_min)) L = np.clip(l_values, 0, nmax).astype(np.uint8) - diffs = scale*L + group_min - data + + diffs = scale * L + group_min - data diffs = np.abs(diffs) if use_mad else diffs**2 - best_mad = np.sum(weight * (diffs)) - + best_mad = np.sum(weight * diffs) + if nstep < 1: the_min = -group_min return scale, L, the_min - - Laux = [] + for step in range(nstep): - iscale = (rmin + rdelta*step + nmax) / (group_max-group_min) - l_values = np.round(iscale * (data-group_min)) + iscale = (rmin + rdelta * step + nmax) / (group_max - group_min) + l_values = np.round(iscale * (data - group_min)) Laux = np.clip(l_values, 0, nmax).astype(np.uint8) - + sum_l = np.sum(weight * Laux) sum_l2 = np.sum(weight * Laux**2) sum_xl = np.sum(weight * Laux * data) - - D = sum_w*sum_l2 - sum_l*sum_l + + D = sum_w * sum_l2 - sum_l * sum_l if D > 0: - this_scale = (sum_w*sum_xl - sum_x*sum_l) / D - this_min = (sum_l2*sum_x - sum_l*sum_xl) / D + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D if this_min > 0: this_min = 0 this_scale = sum_xl / sum_l2 - - diffs = this_scale*Laux + this_min - data - mad = np.sum(weight * diffs**2) - + + diffs = this_scale * Laux + this_min - data + diffs = np.abs(diffs) if use_mad else diffs**2 + mad = np.sum(weight * diffs) + if mad < best_mad: L = Laux.copy() best_mad = mad scale = this_scale group_min = this_min - + the_min = -group_min return scale, L, the_min - - + + @register_block("q2_k") def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale=None, d_wmin_m=None): nb = blocks.shape[0] @@ -232,12 +296,12 @@ def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale output_d = np.empty(nb, dtype=np.float32) output_dmin = np.empty(nb, dtype=np.float32) output_qs = np.empty((nb, QK_K // 16 // 4, 16), dtype=np.uint8) - + blocks = blocks.reshape((nb, QK_K // 16, 16)) weight = np.abs(blocks) scales = np.empty(QK_K // 16, dtype=np.float32) mins = np.empty((QK_K // 16), dtype=np.float32) - + if scale is not None: scale = scale.reshape((-1, QK_K // 16)) wmin_m = wmin_m.reshape((-1, QK_K // 16)) @@ -245,11 +309,11 @@ def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale output_dmin = d_wmin_m.astype(np.float32) inv_scales = np.where(d_scale == 0, 0, 1 / output_d) inv_mins = np.where(d_wmin_m == 0, 0, 1 / output_dmin) - + # for i in tqdm(range(nb), desc="packing layer"): for i in range(nb): all_L = np.empty(blocks[i].shape, dtype=np.uint8) - + if scale is not None: scales = scale[i] mins = wmin_m[i] @@ -265,12 +329,12 @@ def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale all_L[j] = tmp_l scales[j] = tmp_scale mins[j] = the_min - + max_scale = max(scales) max_min = max(mins) inv_scale = 15. / max_scale inv_min = 15. / max_min - + if max_scale > 0: output_scale[i] = np.round(inv_scale * scales).astype(np.uint8) output_d[i] = max_scale / 15. 
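The make_qkx2_quants / make_qkx2_quants_multi hunks above implement the llama.cpp-style weighted search for a per-group (scale, min): start from (max - min)/nmax, then try nstep candidate step sizes and, for each candidate, solve a weighted least-squares fit for scale and min, keeping whichever candidate gives the lowest weighted error. Below is a minimal single-group NumPy sketch of that search using the same closed-form update as the diff; the function name and defaults are illustrative, not part of the patch.

import numpy as np

def kquant_scale_min_search(x, w, nmax=15, rmin=-1.0, rdelta=0.1, nstep=20):
    """Weighted (scale, min) search for one quantization group.

    x: 1-D array, the weights of one group; w: per-element importance.
    Model: x ~= scale * q - the_min, with q in [0, nmax] and the_min >= 0.
    """
    gmin, gmax = min(x.min(), 0.0), x.max()
    if gmin == gmax:
        return 0.0, np.zeros_like(x, dtype=np.uint8), -gmin

    sum_w, sum_x = w.sum(), (w * x).sum()
    scale = (gmax - gmin) / nmax
    q = np.clip(np.round((x - gmin) / scale), 0, nmax)
    best_err = (w * (scale * q + gmin - x) ** 2).sum()
    best = (scale, q.astype(np.uint8), gmin)

    for step in range(nstep):
        iscale = (rmin + rdelta * step + nmax) / (gmax - gmin)
        q_try = np.clip(np.round(iscale * (x - gmin)), 0, nmax)
        sum_l = (w * q_try).sum()
        sum_l2 = (w * q_try ** 2).sum()
        sum_xl = (w * q_try * x).sum()
        D = sum_w * sum_l2 - sum_l ** 2
        if D <= 0:
            continue
        # closed-form weighted least squares for (scale, min) given q_try
        s = (sum_w * sum_xl - sum_x * sum_l) / D
        m = (sum_l2 * sum_x - sum_l * sum_xl) / D
        if m > 0:  # keep the offset non-positive, refit scale alone
            m, s = 0.0, sum_xl / sum_l2
        err = (w * (s * q_try + m - x) ** 2).sum()
        if err < best_err:
            best_err, best = err, (s, q_try.astype(np.uint8), m)

    scale, q, gmin = best
    return scale, q, -gmin  # the min is returned as a positive offset

The cupy path added in int.py applies the same update to all groups at once by keeping every reduction row-wise (axis=1, keepdims=True).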
@@ -282,7 +346,7 @@ def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale output_dmin[i] = max_min / 15. else: output_dmin[i] = 0. - + d_tmp = output_d[i] * (output_scale[i] & 0xF) dm_tmp = output_dmin[i] * (output_scale[i] >> 4) for j in range(QK_K // 16): @@ -292,15 +356,15 @@ def q2_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale all_L[j] = np.round((blocks[i][j] + dm_tmp[j]) / d_tmp[j]).astype(np.uint8) all_L = np.clip(all_L, 0, 3) output_qs[i] = all_L[::4] | (all_L[1::4] << 2) | (all_L[2::4] << 4) | (all_L[3::4] << 6) - + output_d = output_d.reshape(-1, 1).astype(np.float16).view(np.uint8) output_dmin = output_dmin.reshape(-1, 1).astype(np.float16).view(np.uint8) output_qs = output_qs.reshape((nb, QK_K // 4)) - + # [scale, qs, d, dmin] return np.concatenate([output_scale, output_qs, output_d, output_dmin], axis=-1) - - + + @register_block("q4_k") def q4_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale=None, d_wmin_m=None): nb = blocks.shape[0] @@ -308,14 +372,14 @@ def q4_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale output_d = np.empty(nb, dtype=np.float32) output_dmin = np.empty(nb, dtype=np.float32) output_qs = np.empty((nb, QK_K // 64, 32), dtype=np.uint8) - + blocks = blocks.reshape((nb, QK_K // 32, 32)) sum_x2 = np.sum(np.power(blocks, 2), axis=-1) av_x = np.sqrt(sum_x2 / 32) - weight = blocks + av_x.reshape((*av_x.shape, 1)) + weight = np.abs(blocks) + av_x.reshape((*av_x.shape, 1)) scales = np.empty(QK_K // 32, dtype=np.float32) mins = np.empty(QK_K // 32, dtype=np.float32) - + if scale is not None: scale = scale.reshape(-1, QK_K // 32) wmin_m = wmin_m.reshape(-1, QK_K // 32) @@ -323,11 +387,11 @@ def q4_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale output_dmin = d_wmin_m.astype(np.float32) inv_scales = np.where(d_scale == 0, 0, 1 / output_d) inv_mins = np.where(d_wmin_m == 0, 0, 1 / output_dmin) - + # for i in tqdm(range(nb), desc="packing layer"): for i in range(nb): all_L = np.empty(blocks[i].shape, dtype=np.uint8) - + if scale is not None: scales = scale[i] mins = wmin_m[i] @@ -339,46 +403,71 @@ def q4_k_quant_block(blocks: np.array, scale=None, zp=None, wmin_m=None, d_scale else: for j in range(QK_K // 32): tmp_scale, tmp_l, the_min = make_qkx2_quants( - blocks[i][j], weight[i][j], nmax=15, group_size=32, rmin=-1, rdelta=0.1, nstep=20) + blocks[i][j], weight[i][j], nmax=15, group_size=32, rmin=-1, rdelta=0.1, nstep=20, use_mad=False) all_L[j] = tmp_l scales[j] = tmp_scale mins[j] = the_min - + max_scale = max(scales) max_min = max(mins) inv_scale = 63. / max_scale if max_scale > 0 else 0. inv_min = 63. / max_min if max_min > 0 else 0. 
- + ls = np.round(inv_scale * scales).astype(np.uint8) lm = np.round(inv_min * mins).astype(np.uint8) - + output_scale[i][:4] = ls[:4] output_scale[i][4:8] = lm[:4] - + output_scale[i][8:] = (ls[4:] & 0xF) | ((lm[4:] & 0xF) << 4) output_scale[i][:4] |= ((ls[4:] >> 4) << 6) output_scale[i][4:8] |= ((lm[4:] >> 4) << 6) - + if d_scale is None: output_d[i] = max_scale / 63 if d_wmin_m is None: output_dmin[i] = max_min / 63 - + d_tmp = output_d[i] * ls dm_tmp = output_dmin[i] * lm - + for j in range(d_tmp.size): if d_tmp[j] == 0.: continue else: all_L[j] = np.round((blocks[i][j] + dm_tmp[j]) / d_tmp[j]).astype(np.uint8) - + all_L = np.clip(all_L, 0, 15) output_qs[i] = all_L[::2] | (all_L[1::2] << 4) - + output_d = output_d.reshape(-1, 1).astype(np.float16).view(np.uint8) output_dmin = output_dmin.reshape(-1, 1).astype(np.float16).view(np.uint8) output_qs = output_qs.reshape(nb, QK_K // 2) - + # [d, dmin, scale, qs] return np.concatenate([output_d, output_dmin, output_scale, output_qs], axis=-1) + + +if __name__ == "__main__": + from transformers import set_seed + set_seed(42) + data = np.random.randn(2, 32) + print(data) + scale, L, the_min = make_qkx2_quants(data[0], data[0], 15, 32, nstep=20) + print(scale) + print(L) + print(the_min) + print("--" * 20) + + scale, L, the_min = make_qkx2_quants(data[1], data[1], 15, 32, nstep=20) + print(scale) + print(L) + print(the_min) + print("--" * 20) + + scale, L, the_min = make_qkx2_quants_multi(data, data, 15, 32, nstep=20) + print(scale) + print(L) + print(the_min) + + \ No newline at end of file diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py index b3685d9f..36a9af3f 100644 --- a/auto_round/script/llm.py +++ b/auto_round/script/llm.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - + # Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,23 +35,23 @@ clear_memory, get_device_and_parallelism, set_cuda_visible_devices) - + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - - + + class BasicArgumentParser(argparse.ArgumentParser): - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument( "--model", "--model_name", "--model_name_or_path", default="facebook/opt-125m", help="model name or path") - + self.add_argument('--eval', action='store_true', help="whether to use eval only mode") - + self.add_argument("--bits", default=4, type=int, help="number of weight bits") - + self.add_argument("--eval_bs", default=None, type=int, help="batch size in evaluation") - + self.add_argument( "--device", "--devices", @@ -62,38 +62,38 @@ def __init__(self, *args, **kwargs): "The default is set to cuda:0," "allowing for automatic detection and switch to HPU or CPU." 
"set --device 0,1,2 to use multiple cards.") - + self.add_argument("--asym", action='store_true', help="whether to use asym quantization") - + self.add_argument( "--dataset", default="NeelNanda/pile-10k", type=str, help="the dataset for quantization training") - + self.add_argument( "--minmax_lr", default=None, type=float, help="minmax learning rate, if None, it will beset to be the same with lr") - + self.add_argument("--seed", default=42, type=int, help="random seed") - + self.add_argument("--adam", action='store_true', help="whether to use adam optimizer instead of SignSGD") - + self.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps") - + self.add_argument("--nblocks", default=1, type=int, help="how many blocks to tune together") - + self.add_argument("--low_gpu_mem_usage", action='store_true', help="offload intermediate features to cpu") - + self.add_argument("--format", default="auto_round", type=str, help="the format to save the model") - + self.add_argument("--data_type", "--dtype", default='int', help="data type for tuning, 'int', 'mx_fp' and etc") - + self.add_argument( "--scale_dtype", default='fp16', choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="scale data type to use for quantization") - + self.add_argument( "--tasks", "--task", @@ -102,35 +102,35 @@ def __init__(self, *args, **kwargs): "openbookqa,boolq,arc_easy,arc_challenge", default=None, help="lm-eval tasks") - + self.add_argument( "--output_dir", default="./tmp_autoround", type=str, help="the directory to save quantized model") - + self.add_argument("--disable_eval", action='store_true', help="whether to disable lm-eval evaluation after tuning") - + self.add_argument( "--eval_task_by_task", action="store_true", help="whether to eval task by task.") - + self.add_argument("--disable_amp", action='store_true', help="disable amp") - + self.add_argument( "--disable_minmax_tuning", action='store_true', help="whether to disable enable weight minmax tuning") - + self.add_argument("--enable_norm_bias_tuning", action='store_true', help="whether to enable norm bias tuning") - + self.add_argument( "--disable_trust_remote_code", action='store_true', help="whether to disable trust_remote_code") - + self.add_argument( "--disable_quanted_input", action='store_true', help="whether to disuse the output of quantized block to tune the next block") - + self.add_argument("--quant_lm_head", action='store_true', help="whether to quant lm_head") - + self.add_argument( "--low_cpu_mem_mode", default=0, @@ -143,58 +143,58 @@ def __init__(self, *args, **kwargs): "2 means choose layer-wise mode, load the weights of each layer from disk when tuning," " minimum memory consumption and also slowest running speed." "others means not use low cpu memory. Default to 0, not use low cpu memory.") - + self.add_argument( "--low_cpu_mem_tmp_dir", default=None, type=str, help="temporary work space to store the temporary files " "when using low cpu memory mode. 
Will remove after tuning.") - + self.add_argument( "--model_dtype", default=None, type=str, choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="force to convert the dtype, some backends supports fp16 dtype better") - + self.add_argument("--act_bits", default=16, type=int, help="activation bits") - + self.add_argument( "--fp_layers", default="", type=str, help="list of Layer names to maintain original data type") - + self.add_argument( "--not_use_best_mse", action='store_true', help="whether to use the iter of best mes loss in the tuning phase") - + self.add_argument( "--to_quant_block_names", default=None, type=str, help="Names of quantitative blocks, please use commas to separate them.") - + self.add_argument("--enable_torch_compile", action='store_true', help="whether to enable torch compile") - + self.add_argument("--act_data_type", "--act_dtype", default=None, type=str, help="activation data type") - + self.add_argument("--disable_act_dynamic", action='store_true', help="activation static quantization") - + self.add_argument("--disable_deterministic_algorithms", action='store_true', help="disable torch deterministic algorithms.") - + self.add_argument("--device_map", default=None, type=str, help="device_map for block in tuning phase") - + self.add_argument( "--super_group_size", default=None, type=int, help="the number of super group size when use double quant.") - + self.add_argument( "--super_bits", default=None, type=int, help="number of scale and mins quant bits for double quant.") - - + + class EvalArgumentParser(argparse.ArgumentParser): - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument( @@ -209,7 +209,7 @@ def __init__(self, *args, **kwargs): "The default is set to cuda:0," "allowing for automatic detection and switch to HPU or CPU." 
"set --device 0,1,2 to use multiple cards.") - + self.add_argument("--tasks", "--task", default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," \ "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge", @@ -218,148 +218,148 @@ def __init__(self, *args, **kwargs): "--disable_trust_remote_code", action='store_true', help="whether to disable trust_remote_code") self.add_argument("--eval_bs", "--bs", "--batch_size", default=None, type=int, help="batch size in evaluation") self.add_argument("--eval_task_by_task", action='store_true', help="whether to eval task by task.") - - + + def setup_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=200, type=int, help="iteration to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() return args - - + + def setup_best_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=1000, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=512, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() args.low_gpu_mem_usage = True - + return args - - + + def setup_light_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int, help="train batch size") - + parser.add_argument("--iters", "--iter", default=50, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=2048, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=5e-3, type=float, help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() - + return args - - + + def setup_fast_parser(): parser = BasicArgumentParser() - + parser.add_argument("--group_size", default=128, type=int, help="group size") - + parser.add_argument("--batch_size", "--train_bs", "--bs", default=4, type=int, help="train batch size") - + parser.add_argument("--iters", default=200, type=int, help="iterations to tune each block") - + parser.add_argument( "--seqlen", "--seq_len", default=512, type=int, help="sequence length of the calibration samples") - + parser.add_argument("--nsamples", "--nsample", default=128, type=int, help="number of samples") - + parser.add_argument( "--lr", default=None, type=float, 
help="learning rate, if None, it will be set to 1.0/iters automatically") - + args = parser.parse_args() - + return args - - + + def setup_eval_parser(): parser = EvalArgumentParser() args = parser.parse_args() return args - - + + def tune(args): if args.disable_eval: logging.warning("`disable_eval` is deprecated and is now set by default.") - + if args.eval_bs is None: args.eval_bs = "auto" - + import transformers - + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig - + from auto_round.utils import detect_device, get_library_version from auto_round.utils import logger, _gguf_args_check - + if args.format is None: args.format = "auto_round" - + formats = args.format.lower().replace(' ', '').split(",") from auto_round.utils import supported_formats for format in formats: if format not in supported_formats: raise ValueError(f"{format} is not supported, we only support {supported_formats}") - + args = _gguf_args_check(args) - + if "auto_gptq" in args.format and args.asym is True: logger.warning("the auto_gptq kernel has issues with asymmetric quantization. " "It is recommended to use sym quantization or --format='auto_round'") - + if "marlin" in args.format and args.asym is True: assert False, "marlin backend only supports sym quantization, please remove --asym" - + ##must set this before import torch set_cuda_visible_devices(args.device) device_str, use_auto_mapping = get_device_and_parallelism(args.device) - + import torch if not args.disable_deterministic_algorithms: torch.use_deterministic_algorithms(True, warn_only=True) # logger.info("`torch.use_deterministic_algorithms` is enabled by default for reproducibility " # "and can be disabled using the `--disable_deterministic_algorithms` argument.") - + if args.enable_torch_compile: logger.info("`torch.compile` is enabled to reduce tuning costs. 
" "If it causes issues, you can disable it by remove `--enable_torch_compile` argument.") - + model_name = args.model if model_name[-1] == "/": model_name = model_name[:-1] @@ -367,7 +367,7 @@ def tune(args): torch_dtype = "auto" if device_str is not None and "hpu" in device_str: torch_dtype = torch.bfloat16 - + from auto_round.utils import llm_load_model model, tokenizer, low_cpu_mem_usage = llm_load_model( model_name, @@ -378,25 +378,25 @@ def tune(args): low_cpu_mem_mode=args.low_cpu_mem_mode, low_cpu_mem_tmp_dir=args.low_cpu_mem_tmp_dir, model_dtype=args.model_dtype) - + from auto_round import AutoRound, AutoRoundAdam - + seqlen = args.seqlen - + if hasattr(tokenizer, "model_max_length"): if tokenizer.model_max_length < seqlen: logger.info( f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length") seqlen = min(seqlen, tokenizer.model_max_length) args.seqlen = seqlen - + if "bloom" in model_name: args.low_gpu_mem_usage = False - + round = AutoRound if args.adam: round = AutoRoundAdam - + layer_config = {} for n, m in model.named_modules(): if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D): @@ -405,7 +405,7 @@ def tune(args): logger.info( f"{n} will not be quantized due to its shape not being divisible by 32," " resulting in an exporting issue to autogptq") - + not_quantize_layer_names = get_fp_layer_names(model, args.fp_layers) for name in not_quantize_layer_names: layer_config[name] = {"bits": 16} @@ -415,7 +415,7 @@ def tune(args): if "auto_round" not in format and "fake" not in format and "awq" not in format: ##TODO gptq could support some mixed precision config logger.warning(f"mixed precision exporting does not support {format} currently") - + lm_head_layer_name = "lm_head" for n, _ in model.named_modules(): lm_head_layer_name = n @@ -430,7 +430,7 @@ def tune(args): f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " f"supported currently") break - + if args.quant_lm_head: layer_config[lm_head_layer_name] = {"bits": args.bits} for format in formats: @@ -438,16 +438,16 @@ def tune(args): auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")] raise ValueError( f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}") - + if "auto_awq" in args.format: from auto_round.utils import check_awq_gemm_compatibility awq_supported, info = check_awq_gemm_compatibility( model, args.bits, args.group_size, not args.asym, layer_config) if not awq_supported: logger.warning(f"The AutoAWQ format may not be supported due to {info}") - + enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False - + autoround = round( model, tokenizer, @@ -484,59 +484,60 @@ def tune(args): super_group_size=args.super_group_size, super_bits=args.super_bits, ) - + model_name = args.model.rstrip("/") if model_name.split('/')[-1].strip('.') == "": export_dir = os.path.join(args.output_dir, f"w{args.bits}g{args.group_size}") else: export_dir = os.path.join(args.output_dir, model_name.split('/')[-1] + f"-w{args.bits}g{args.group_size}") - + model, folders = autoround.quantize_and_save(export_dir, format=args.format) - + if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2: import shutil shutil.rmtree(args.low_cpu_mem_tmp_dir, ignore_errors=True) - + model.eval() clear_memory() - + lm_eval_version = get_library_version("lm-eval") - + eval_folder = folders[-1] if args.tasks is None or args.tasks == "" or eval_folder is 
None: return - + tasks = args.tasks if isinstance(tasks, str): tasks = tasks.split(',') - + from lm_eval.utils import make_table # pylint: disable=E0401 - + logger.info(f"Using lm-eval version {lm_eval_version}") eval_gguf_model = False for file in os.listdir(eval_folder): if file.endswith("gguf"): eval_gguf_model = True break - + if args.act_bits <= 8 or eval_gguf_model: if eval_gguf_model: # gguf floder only contains one file for file in os.listdir(eval_folder): gguf_file = file user_model = AutoModelForCausalLM.from_pretrained( - eval_folder, gguf_file=gguf_file, device_map="auto" if use_auto_mapping else None) + eval_folder, gguf_file=gguf_file, device_map="auto") + user_model = user_model.to(torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(eval_folder, gguf_file=gguf_file) else: if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: from accelerate.big_modeling import dispatch_model - + dispatch_model(model, model.hf_device_map) user_model = model else: device_str = detect_device(device_str) user_model = model.to(device_str) - + if args.eval_task_by_task: eval_task_by_task( user_model, tokenizer=tokenizer, device=device_str, tasks=args.tasks, batch_size=args.eval_bs) @@ -558,8 +559,8 @@ def tune(args): res = simple_evaluate( model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) print(make_table(res)) - - + + def _eval_init(tasks, model_path, device, disable_trust_remote_code=False): set_cuda_visible_devices(device) device_str, parallelism = get_device_and_parallelism(device) @@ -569,35 +570,57 @@ def _eval_init(tasks, model_path, device, disable_trust_remote_code=False): if isinstance(tasks, str): tasks = tasks.split(',') return tasks, model_args, device_str - - + + def eval(args): tasks, model_args, device_str = _eval_init(args.tasks, args.model, args.device, args.disable_trust_remote_code) - + # load after _eval_int in order to make sure import torch after set CUDA_VISBILE_DEVICES - from auto_round.eval.evaluation import simple_evaluate - - res = simple_evaluate(model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) - - from lm_eval.utils import make_table # pylint: disable=E0401 - print(make_table(res)) - - + from auto_round.eval.evaluation import simple_evaluate, simple_evaluate_user_model + + is_gguf_file = False + if os.path.isfile(args.model) and args.model.endswith(".gguf"): + is_gguf_file = True + gguf_file = os.path.basename(args.model) + model = os.path.dirname(args.model) + else: + for file in os.listdir(model): + if file.endswith(".gguf"): + is_gguf_file = True + gguf_file = file + if is_gguf_file: + import torch + from transformers import AutoTokenizer, AutoModelForCausalLM + from lm_eval.utils import make_table # pylint: disable=E0401 + tokenizer = AutoTokenizer.from_pretrained(model, gguf_file=gguf_file) + user_model = AutoModelForCausalLM.from_pretrained(model, gguf_file=gguf_file, device_map="auto") + user_model = user_model.to(torch.bfloat16) + res = simple_evaluate_user_model( + user_model, tokenizer, tasks=tasks, batch_size=args.eval_bs, device=device_str) + print(make_table(res)) + else: + res = simple_evaluate( + model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs) + + from lm_eval.utils import make_table # pylint: disable=E0401 + print(make_table(res)) + + def eval_task_by_task( model, device=None, tasks=None, tokenizer=None, batch_size=None, max_batch_size=64, trust_remote_code=True): set_cuda_visible_devices(device) 
device_str, parallelism = get_device_and_parallelism(device) - + # load after _eval_int in order to make sure import torch after set CUDA_VISBILE_DEVICES import traceback from auto_round.utils import logger from lm_eval import simple_evaluate as lm_simple_evaluate from lm_eval.models.huggingface import HFLM from transformers import AutoModelForCausalLM, AutoTokenizer - + from auto_round import AutoRoundConfig # pylint: disable=E0611 if batch_size is None: - batch_size = "auto" + batch_size = "auto:8" is_gguf_file = False if not isinstance(model, str): parallelism = False @@ -612,8 +635,11 @@ def eval_task_by_task( is_gguf_file = True gguf_file = file if is_gguf_file: + import torch tokenizer = AutoTokenizer.from_pretrained(model, gguf_file=gguf_file) + # float32 model = AutoModelForCausalLM.from_pretrained(model, gguf_file=gguf_file, device_map="auto") + model = model.to(torch.bfloat16) hflm = HFLM( pretrained=model, tokenizer=tokenizer, @@ -622,13 +648,15 @@ def eval_task_by_task( max_batch_size=max_batch_size, parallelize=parallelism, trust_remote_code=trust_remote_code) - + if isinstance(tasks, str): tasks = tasks.replace(" ", "").split(",") - + from lm_eval.utils import make_table # pylint: disable=E0401 res_all = {} res_keys = ["results", "versions", "n-shot", "higher_is_better"] + import time + st = time.time() for task in tasks: try: res = lm_simple_evaluate(model=hflm, model_args=None, device=device_str, tasks=task, batch_size=batch_size) @@ -646,11 +674,14 @@ def eval_task_by_task( except Exception as e: traceback.print_exc() continue - + if not res_all: res_all = res else: for key in res_keys: res_all[key].update(res[key]) print(make_table(res_all)) - + + print("total eval time:", time.time() - st) + + \ No newline at end of file diff --git a/q2_test_search_no_tune.sh b/q2_test_search_no_tune.sh new file mode 100644 index 00000000..c98d6cbd --- /dev/null +++ b/q2_test_search_no_tune.sh @@ -0,0 +1,19 @@ +for model_name in "Qwen2.5-7B-Instruct" "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do +device=5 +format=fake +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format ${format} \ + --data_type int_asym_dq \ + --group_size 16 \ + --super_bits 4 \ + --act_bits 16 \ + --super_group_size 16 \ + --bits 2 \ + --iters 200 \ + --asym \ + --model /models/${model_name} \ + --output_dir /data5/shiqi/${format}_q2_k_s_${model_name}_search_no_tune \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + 2>&1 | tee /data5/shiqi/log/gguf_test/${format}_q2_k_s_${model_name}_search_no_tune.log +done \ No newline at end of file diff --git a/q4_test_search_no_tune.sh b/q4_test_search_no_tune.sh new file mode 100644 index 00000000..0cbbb678 --- /dev/null +++ b/q4_test_search_no_tune.sh @@ -0,0 +1,17 @@ +for model_name in "Qwen2.5-7B-Instruct" "falcon-three-7b" "Meta-Llama-3.1-8B-Instruct" "phi-4"; do +device=0 +format=fake +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format fake \ + --data_type int_asym_dq \ + --group_size 32 \ + --super_bits 6 \ + --super_group_size 8 \ + --bits 4 \ + --iters 200 \ + --model /models/${model_name} \ + --output_dir /data5/shiqi/${format}_${model_name}_search \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + 2>&1 | tee /data5/shiqi/log/gguf_test/${format}_${model_name}_search.log +done \ No newline at end of file diff --git a/test_zero_iter.sh b/test_zero_iter.sh new file mode 
100644 index 00000000..f4d9e298 --- /dev/null +++ b/test_zero_iter.sh @@ -0,0 +1,18 @@ +#"Qwen2.5-7B-Instruct" "Meta-Llama-3.1-8B-Instruct" "falcon-three-7b" "phi-4" +for model_name in "Qwen2.5-1.5B-Instruct"; do +device=4 +format=fake +CUDA_VISIBLE_DEVICES=$device python -m auto_round \ + --format ${format} \ + --data_type int_asym_dq \ + --group_size 32 \ + --super_bits 6 \ + --super_group_size 8 \ + --bits 4 \ + --iters 0 \ + --model /models/${model_name} \ + --output_dir /data5/shiqi/${format}_${model_name}_zero \ + --eval_bs 16 \ + --tasks arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,mmlu,openbookqa,piqa,truthfulqa_mc1,winogrande \ + 2>&1 | tee /data5/shiqi/log/gguf_test/${format}_${model_name}_zero.log +done \ No newline at end of file
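The three scripts above all drive the int_asym_dq path, where the per-group scales and mins produced by the k-quant search are themselves quantized to super_bits over blocks of super_group_size (double quantization), as done by double_quant_tensor in the int.py hunk. A rough standalone sketch of that second-level step is shown below; it assumes plain rounding in place of round_ste and a second-level scale of vmax/maxq (the exact derivation lives in double_quant_tensor), with toy sizes mirroring --super_bits 6 --super_group_size 8.

import torch

def double_quant(values, bits, q_scale_thresh=1e-5):
    # Unsigned re-quantization of positive per-group statistics (scales or mins),
    # in the spirit of double_quant_tensor above; plain round, no STE, and the
    # second-level scale vmax/maxq is an illustrative assumption.
    maxq = 2 ** bits - 1
    vmax = torch.clamp(values.max(dim=-1, keepdim=True)[0], min=0)
    d = torch.clamp(vmax / maxq, min=q_scale_thresh)  # second-level scale
    qdq = torch.clamp(torch.round(values / d), max=maxq) * d
    return qdq, d

# 16 group scales, re-quantized in super-groups of 8 with 6 bits
group_scales = torch.rand(16) + 1e-3
qdq_scales, d_scale = double_quant(group_scales.view(-1, 8), bits=6)
print(qdq_scales.view(-1))
print(d_scale.view(-1))

Only the second-level scale d (plus the quantized scale/min codes) needs to be kept per super-group, which is what the q2_k and q4_k packers consume through d_scale and d_wmin_m.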