Commit 0045d88

Match QAT prepare and convert numerics exactly for bf16 and fp16 (#2060)
**Summary:** The previous PR #1964 got this to match for fp32, but there were three additional sources of numerical discrepancy with bf16:

1. QAT asymmetric per token choose qparams diverged from `choose_qparams_affine`, which had simpler logic
2. QAT per token fake quantize cast the input to fp32 before fake quantizing it
3. QAT symmetric per group choose qparams used a hardcoded eps value that did not match `choose_qparams_affine`

These are all resolved in this commit: (1) QAT now uses `choose_qparams_affine` instead of the custom function for asymmetric per token, which is now deleted, (2) QAT no longer casts the input to fp32, and (3) QAT now uses an eps value that corresponds to the input dtype. The result is an exact match in numerics between the prepare and convert steps for fp32, bf16, and fp16.

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
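As a quick illustration of the prepare vs convert match being tested, here is a minimal sketch in the spirit of `test_fake_quantize_per_token_vs_convert`; the `FakeQuantizeConfig`/`FakeQuantizer` import paths are assumptions and may differ by torchao version:

```python
import torch

# Assumed import locations; adjust to the torchao version you have installed.
from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.utils import per_token_dynamic_quant

# bf16 activations: the QAT path previously cast these to fp32 and used its own
# qparam logic, so the fake-quantized values could differ from the converted model.
x = torch.randn(1, 235, 2048).to(torch.bfloat16)

# "Prepare" path: QAT fake quantizer with an asymmetric per-token int8 config.
config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
fake_quantizer_out = FakeQuantizer(config)(x)

# "Convert"-style baseline: dynamic per-token quantize + dequantize.
baseline_out = per_token_dynamic_quant(x)

# After this change the two match exactly, including for bf16 and fp16 inputs.
torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
```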
1 parent bd747a2 commit 0045d88

File tree

4 files changed (+59 −75 lines)


test/quantization/test_qat.py (+42 −17)

@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn.functional as F
+from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 from torchao import quantize_
@@ -40,7 +41,6 @@
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
@@ -53,12 +53,16 @@
     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
     fake_quantize_affine,
+    quantize_affine,
 )
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -134,12 +138,13 @@ def forward(self, x):
 
 
 class M4(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype: torch.dtype = torch.float32):
         super().__init__()
-        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.dtype = dtype
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)
 
     def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+        return (torch.randn(1, 512).to(self.dtype),)
 
     def forward(self, x):
         return self.linear(x)
@@ -219,30 +224,41 @@ def test_fake_quantize_per_token(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+        block_size = _get_per_token_block_size(x)
+        (s, zp) = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+        )
 
         # fake quant op
         out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
         out.sum().backward()
 
         # compare against PTQ ops
-        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+        out_ptq = quantize_affine(
             x2,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
         )
-        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+        out_ptq = dequantize_affine(
             out_ptq,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
-            torch.float32,
+            output_dtype=torch.float32,
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
 
@@ -1004,8 +1020,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
     """
     # activations
-    (s, zp) = _choose_qparams_per_token_asymmetric(
-        x, torch.float32, torch.int32
+    (s, zp) = choose_qparams_affine(
+        x,
+        mapping_type=MappingType.ASYMMETRIC,
+        block_size=_get_per_token_block_size(x),
+        target_dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int32,
     )
     (qmin, qmax) = _get_qmin_qmax(8)
     x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -1427,10 +1450,11 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_fake_quantize_per_token_vs_convert(self):
+    def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
         1. FakeQuantizer with asymmetric per_token config
@@ -1439,17 +1463,18 @@ def test_fake_quantize_per_token_vs_convert(self):
         from torchao.quantization.utils import per_token_dynamic_quant
 
         torch.manual_seed(self.SEED)
-        x = torch.randn(1, 235, 2048)
+        x = torch.randn(1, 235, 2048).to(dtype)
         config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         fake_quantizer = FakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
 
+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_qat_8da4w_prepare_vs_convert(self):
+    def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
         numerics that match exactly over N trials.
@@ -1463,7 +1488,7 @@ def test_qat_8da4w_prepare_vs_convert(self):
 
         for seed in range(self.SEED, self.SEED + num_trials):
             torch.manual_seed(seed)
-            m = M4()
+            m = M4(dtype)
             torch.manual_seed(seed)
             x = m.example_inputs()
 
torchao/quantization/qat/fake_quantizer.py (+14 −5)

@@ -16,8 +16,11 @@
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
+    MappingType,
+    choose_qparams_affine,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
 )
@@ -26,7 +29,6 @@
     FakeQuantizeConfig,
 )
 from .utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
 )
@@ -69,13 +71,19 @@ def _per_token_forward(self, x: torch.Tensor):
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
+
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
-            (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric(
+            self.scale, self.zero_point = choose_qparams_affine(
                 x,
-                self.config.scale_precision,
-                self.config.zero_point_precision,
+                mapping_type=MappingType.ASYMMETRIC,
+                block_size=_get_per_token_block_size(x),
+                target_dtype=self.config.dtype,
+                quant_min=qmin,
+                quant_max=qmax,
+                scale_dtype=self.config.scale_precision,
+                zero_point_dtype=self.config.zero_point_precision,
             )
-        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
     def _per_channel_or_group_forward(self, x: torch.Tensor):
@@ -100,6 +108,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
             raise ValueError("Unexpected granularity '%s'" % granularity)
 
         # get scales and zero points
+        # TODO: refactor this to use `choose_qparams_affine`
         if self._should_compute_qparams():
             bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype]
             if is_symmetric:

torchao/quantization/qat/utils.py (+2 −52)

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Tuple
+from typing import List
 
 import torch
 
@@ -126,9 +126,8 @@ def _fake_quantize_per_token(
 
     _per_token_quant_qparam_dim_check(input, scales, zero_points)
     block_size = _get_per_token_block_size(input)
-    fq_input = input.to(torch.float32)
     fq = _GenericFakeQuantize.apply(
-        fq_input,
+        input,
         block_size,
         scales,
         zero_points,
@@ -138,55 +137,6 @@ def _fake_quantize_per_token(
     return fq.reshape_as(input).to(input.dtype)
 
 
-# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
-# The version in pytorch does not have backward support yet so we add
-# it here for now until https://github.com/pytorch/pytorch/pull/123452
-# is landed.
-def _choose_qparams_per_token_asymmetric(
-    input: torch.Tensor,
-    scales_precision: torch.dtype = torch.float32,
-    zero_points_precision: torch.dtype = torch.float32,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
-    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
-    every N elements with the same quantization parameter. The dimension for scales/zero_points
-    will be (M1 * M2 ... * Mn)
-
-    Args:
-        input (torch.Tensor): original float32/float16 Tensor
-        scales_precision (torch.dtype): precision of returned scales
-        zero_points_precision (torch.dtype): precision of returned zero points
-
-    Returns:
-        scales and zero_points, both float32 Tensors
-    """
-    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
-    qmin, qmax = -128, 127
-    min_val = torch.amin(input, dim=-1, keepdim=True)
-    max_val = torch.amax(input, dim=-1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
-
-    # scale
-    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
-    scale = scale.clamp(min=eps)
-
-    # zero point
-    descaled_min = min_val_neg / scale
-    descaled_max = max_val_pos / scale
-    zero_point_from_min_error = qmin + descaled_min
-    zero_point_from_max_error = qmax + descaled_max
-    zero_point = torch.where(
-        zero_point_from_min_error + zero_point_from_max_error > 0,
-        qmin - descaled_min,
-        qmax - descaled_max,
-    )
-    zero_point = torch.clamp(zero_point, qmin, qmax).round()
-
-    return scale.to(scales_precision), zero_point.to(zero_points_precision)
-
-
 def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
     if symmetric:
         qmin = -(2 ** (n_bit - 1))

torchao/quantization/utils.py (+1 −1)

@@ -540,7 +540,7 @@ def get_group_qparams_symmetric(
     assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
     block_size = (1, groupsize)
-    eps = torch.finfo(torch.float32).eps
+    eps = torch.finfo(w.dtype).eps
     ranges = {}
     ranges[1] = (-1, 0)
     # generating ranges for bit 2 to 8
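As a side note on the `eps` change above: the clamp floor for scales now follows the weight dtype. A quick illustration of how far apart the two eps values are (plain `torch.finfo`, not code from this commit):

```python
import torch

# The fp32 eps is roughly four orders of magnitude smaller than the bf16 eps,
# so clamping bf16 scales to the fp32 value diverges from choose_qparams_affine,
# which uses the eps of the input dtype.
print(torch.finfo(torch.float32).eps)   # 1.1920928955078125e-07
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
```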
