diff --git a/torchao/quantization/GPTQ/GPTQ.py b/torchao/quantization/GPTQ/GPTQ.py index fe55ed19db..f20a5a4965 100644 --- a/torchao/quantization/GPTQ/GPTQ.py +++ b/torchao/quantization/GPTQ/GPTQ.py @@ -295,7 +295,7 @@ def __torch_function__( SQNR(DQ, DQ_from_qtensor), ) - qparams2 = cls.get_qparams_func(W) + qparams2 = cls.get_qparams_func(W, W.dtype) Q2 = cls.quantize_func(W, qparams2) DQ2 = cls.dequantize_func(Q2, qparams2).to(W.dtype) old_q_out = ( @@ -444,7 +444,9 @@ def faster_quant(cls, H, W, device): group_end = min(group_start + group_size, columns) if group_start % group_size == 0: # needed for when group_size == columns so only calculate qparams once - cur_qparams = cls.get_qparams_func(W[:, group_start:group_end]) + cur_qparams = cls.get_qparams_func( + W[:, group_start:group_end], orig_dtype + ) all_qparams.append(cur_qparams) for index in range(group_start, group_end): # within each group @@ -679,10 +681,11 @@ def __init__( else: self.zero_point_domain = ZeroPointDomain.FLOAT - self.get_qparams_func = lambda w: get_groupwise_affine_qparams( + self.get_qparams_func = lambda w, precision: get_groupwise_affine_qparams( w, n_bit, group_size, + dtype=precision, zero_point_domain=self.zero_point_domain, ) self.quantize_func = (