Draft pull request #546: wants to merge 13 commits into base: main
auto_round/data_type/int.py: 102 changes (92 additions, 10 deletions)
@@ -15,7 +15,8 @@
import torch
from .utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad
from auto_round.data_type.register import register_dtype

import numpy as np
from concurrent.futures import ProcessPoolExecutor

@register_dtype("int_sym")
def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
@@ -62,7 +63,6 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
return qdq_result, scale, zp


## the values should be positive
def double_quant_tensor(tensor, bits, q_scale_thresh):
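"""Quantize an already-positive tensor (e.g. per-group scales or minima) to `bits` bits.
Rounds tensor / scale, clamps the integer code at 2**bits - 1, and returns both the
dequantized tensor and the second-level scale used for this double quantization.
"""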
maxq = 2 ** bits - 1
@@ -72,7 +72,6 @@ def double_quant_tensor(tensor, bits, q_scale_thresh):
qdq_tensor = torch.clamp(round_ste(tensor / scale), max=maxq) * scale
return qdq_tensor, scale


@register_dtype("int_asym_dq")
def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, super_group_size=8, super_bits=6,
@@ -109,27 +108,111 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_
else:
wmin = wmin_tmp
wmax = wmax_tmp
# scale_old = ((wmax - wmin) / maxq).to(scale_dtype)
scale, wmin_m = quant_tensor_k_quant_cuda(tensor, num_bits=bits, group_size=group_size)
scale = scale.squeeze(-1)
scale = torch.from_numpy(scale).to(tensor.dtype).cuda()
wmin_m = torch.from_numpy(wmin_m).to(tensor.dtype).cuda()
scale = torch.clamp(scale, min=q_scale_thresh)
wmin_m = torch.clamp(wmin_m, min=q_scale_thresh)
scale = scale.view(-1, super_group_size)
# wmin_m = -wmin  # pylint: disable=E1130  (replaced by the k-quant minimum computed above)
wmin_m = wmin_m.view(-1, super_group_size)

# conduct double quant
scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh)
wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh)

scale = scale.view(-1, 1)
scale = torch.clamp(scale, q_scale_thresh)
wmin_m = wmin_m.view(-1, 1)

int_w = round_ste((tensor + wmin_m) / scale + v)
q = torch.clamp(int_w, 0, maxq)
qdq_result = (scale * q - wmin_m).to(tensor.dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
# zp = round_ste(wmin_m / scale) # remove this later
return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m}
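A minimal usage sketch for the double-quant path above. The call is hypothetical: the weight shape, bit widths, and group sizes are illustrative assumptions, a CUDA device is assumed because the helper moves the k-quant results onto the GPU, and the import path simply follows the file shown in this diff.

import torch
from auto_round.data_type.int import quant_tensor_asym_dq

w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)  # hypothetical weight
qdq, scale_info, wmin_info = quant_tensor_asym_dq(w, bits=4, group_size=32, super_group_size=8, super_bits=6)
# qdq keeps the original shape; scale_info holds {"scale", "d_scale"}, wmin_info holds {"wmin_m", "d_wmin_m"}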

def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
"""Quantize tensor per group based on k quant.
Ref: https://github.com/intel/neural-compressor/pull/2169/files
Args:
data : input weight
num_bits (int, optional): number of quantization bits. Defaults to 4.
group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
Returns:
scale: per-group scale (numpy array)
wmin: per-group minimum offset, i.e. -rmin (numpy array)
"""
try:
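# Implementation note: this follows the k-quant search from the neural-compressor reference above.
# Each group starts from a plain asymmetric min-max scale; `nstep` candidate scales around it are
# then tried, and for every candidate the best (scale, min) pair is recomputed in closed form by a
# weighted least-squares fit. The candidate with the lowest weighted squared error wins.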
import cupy as cp
import torch

if torch.cuda.is_available():
data = data.to(torch.float64)
data = cp.asarray(data)
data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
maxq = 2**num_bits - 1
minq = 0
sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
mask = rmin != rmax
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
scale = 1 / iscale
quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
diff = scale * quant_data + rmin - data # (nb, group_size)
best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
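# best_mad holds, per group, the weighted squared reconstruction error of the current best
# (scale, rmin) pair; the search below only keeps candidates that lower it.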
nstep = 20
rdelta = 0.1
rrmin = -1
for is_ in range(nstep):
iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
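# factor sweeps from (maxq - minq) - 1 to (maxq - minq) + 0.9 in steps of rdelta = 0.1, i.e.
# candidate inverse scales slightly tighter and slightly wider than the plain min-max choice.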
mask = rmin != rmax
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
mul_weights_quant_data_new = weights * quant_data_new
sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
D = cp.subtract(sum_w * sum_l2, sum_l**2) # (nb, 1)

this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
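# Closed-form weighted least squares: minimizing sum_i w_i * (scale * q_i + min - x_i)^2 over
# (scale, min) yields a 2x2 normal-equation system with determinant D; this_scale and this_min
# are its solution for the current candidate codes quant_data_new.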

diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)

mad_1 = cp.array(mad)
best_mad_1 = cp.array(best_mad)
# idx_to_replace = cp.where((mad_1 < best_mad_1) & (D > 0))[0]
idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
Reviewer comment (Contributor):
change this line to idx_to_replace = cp.where((mad_1 < best_mad_1) & (D > 0))[0]
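For context on the suggestion: D is the determinant of the weighted least-squares system computed above, and it is zero when all candidate codes in a group are identical; in that degenerate case the division produces NaN/Inf scales and minima, so such candidates must not replace the current best. With the guard applied, the selection line would read (a sketch using only the arrays already defined above):

idx_to_replace = cp.where((mad_1 < best_mad_1) & (D > 0))[0]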

quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]

scale = scale.astype(cp.float64)
rmin = rmin.astype(cp.float64)
return scale.get(), -rmin.get()
else:
logger.warning(
"Tried to run k-quant quantization on CUDA, but CUDA is not available. "
"Falling back to k-quant quantization on CPU."
)
return quant_tensor_k_quant_cpu(data, num_bits, group_size)
except ImportError:
logger.info(
"Running k-quant quantization on CPU, which is time-consuming. "
"Consider installing cupy to speed it up on CUDA (see https://cupy.dev/), "
"and install torch so that CUDA availability can be checked."
)
return quant_tensor_k_quant_cpu(data, num_bits, group_size)

@register_dtype("int_asym")
def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
@@ -230,7 +313,6 @@ def quant_tensor_sym_gptq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
return qdq_result, scale, zp


def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0,
scale_dtype=torch.float16,
tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs):