
Align scale dtype with model precision in GPTQ #2403


Closed
9 changes: 6 additions & 3 deletions torchao/quantization/GPTQ/GPTQ.py
@@ -295,7 +295,7 @@ def __torch_function__(
SQNR(DQ, DQ_from_qtensor),
)

- qparams2 = cls.get_qparams_func(W)
+ qparams2 = cls.get_qparams_func(W, W.dtype)

xiaowangintel (Contributor Author), Jun 30, 2025:

No. In the original GPTQ implementation, the weight-scale calculation does not specify a scale dtype, so it falls back to the default dtype=bfloat16 of the get_groupwise_affine_qparams function. And it's not suitable for other model precisions. We want to align the scale dtype with the dtype of the linear weight, so this change passes a dtype parameter to get_groupwise_affine_qparams.
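
To illustrate the mismatch, a minimal toy sketch (toy_qparams is a simplified stand-in for get_groupwise_affine_qparams, not the actual torchao code):

import torch

# Toy stand-in for get_groupwise_affine_qparams: groupwise min/max scales,
# cast to a caller-chosen dtype that defaults to bfloat16, as in the original.
def toy_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
    groups = w.reshape(-1, groupsize)
    scales = groups.amax(dim=1, keepdim=True) - groups.amin(dim=1, keepdim=True)
    scales = scales.clamp(min=1e-6) / (2**n_bit - 1)
    return scales.to(dtype)

w = torch.randn(64, 256, dtype=torch.float16)           # float16 model weight
assert toy_qparams(w).dtype == torch.bfloat16           # default: mismatched scales
assert toy_qparams(w, dtype=w.dtype).dtype == w.dtype   # fix: pass the weight dtype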

jerryzh168 (Contributor):

You mean it's required for the GPTQ algorithm to use hardcoded bfloat16 by default?

> And it's not suitable for other model precisions.

What does this mean?

xiaowangintel (Contributor Author):

We want to enable float16 precision for GPTQ on Intel GPU, but we ran into a dtype misalignment issue: following the original implementation, the weight-quantization scales came out as bfloat16 even though the model weights are float16.

xiaowangintel (Contributor Author):

This PR aims to align the dtype of the weight-quantization scales with the original weight dtype.

jerryzh168 (Contributor):

Would this work for you?

def get_groupwise_affine_qparams(
    w,
    n_bit=4,
    groupsize=128,
    dtype=None,
    zero_point_domain=ZeroPointDomain.FLOAT,
    preserve_zero=False,
    eps=None,
):
    if dtype is None:
        dtype = w.dtype
    ...

xiaowangintel (Contributor Author):

No; the reasons are given below.

Q2 = cls.quantize_func(W, qparams2)
DQ2 = cls.dequantize_func(Q2, qparams2).to(W.dtype)
old_q_out = (
@@ -444,7 +444,9 @@ def faster_quant(cls, H, W, device):
group_end = min(group_start + group_size, columns)
if group_start % group_size == 0:
# needed for when group_size == columns so only calculate qparams once
- cur_qparams = cls.get_qparams_func(W[:, group_start:group_end])
+ cur_qparams = cls.get_qparams_func(
+     W[:, group_start:group_end], orig_dtype
jerryzh168 (Contributor), Jul 1, 2025:

Will orig_dtype here be different from W.dtype?

Edit: it seems not; orig_dtype = W.dtype.

xiaowangintel (Contributor Author):

Yes, they are different. See

W = W.detach().float()

The model weights are converted to float32 during GPTQ quantization calibration, so inside faster_quant W.dtype is float32 while orig_dtype still holds the model's original precision.
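
A minimal sketch of that behavior (just the cast shown above, with toy shapes):

import torch

# Inside faster_quant the working copy is promoted to float32 for the
# calibration math, so W.dtype no longer reflects the model precision;
# orig_dtype is captured before the cast and must be threaded through.
W = torch.randn(16, 128, dtype=torch.float16)  # model weight
orig_dtype = W.dtype                           # torch.float16
W = W.detach().float()                         # GPTQ calibration cast
assert W.dtype == torch.float32
assert orig_dtype == torch.float16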

jerryzh168 (Contributor):

OK, I see.

+ )
all_qparams.append(cur_qparams)

for index in range(group_start, group_end): # within each group
@@ -679,10 +681,11 @@ def __init__(
else:
self.zero_point_domain = ZeroPointDomain.FLOAT

- self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
+ self.get_qparams_func = lambda w, precision: get_groupwise_affine_qparams(
w,
n_bit,
group_size,
+ dtype=precision,
zero_point_domain=self.zero_point_domain,
)
self.quantize_func = (
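
Taken together, a short sketch of the wiring after this change (the import path is assumed from the torchao source tree; the signature follows the snippet quoted in the thread above):

import torch
from torchao.quantization.utils import get_groupwise_affine_qparams

# The lambda now takes an explicit precision and forwards it as the
# scale dtype, so scales match the model precision even though GPTQ
# computes on a float32 working copy of the weight.
n_bit, group_size = 4, 128
get_qparams_func = lambda w, precision: get_groupwise_affine_qparams(
    w, n_bit, group_size, dtype=precision
)

w = torch.randn(32, 256, dtype=torch.float16)   # float16 model weight
scales, zeros = get_qparams_func(w.float(), w.dtype)
assert scales.dtype == torch.float16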