 import torch
 import torch.nn.functional as F
+from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

 from torchao import quantize_

     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,

     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
     fake_quantize_affine,
+    quantize_affine,
 )
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
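The newly imported `_get_per_token_block_size` helper is what lets the generic affine primitives express per-token quantization: every slice along the last dimension (a "token") gets its own scale and zero point. A minimal sketch of the idea, using an illustrative reimplementation rather than torchao's actual helper:

```python
import torch

def per_token_block_size(x: torch.Tensor) -> list:
    # One block per token: size 1 along every leading dim, full width on the last dim.
    return [1] * (x.dim() - 1) + [x.size(-1)]

x = torch.randn(1, 235, 2048)
print(per_token_block_size(x))  # [1, 1, 2048] -> one (scale, zero_point) pair per token
```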
@@ -134,12 +138,13 @@ def forward(self, x):


 class M4(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype: torch.dtype = torch.float32):
         super().__init__()
-        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.dtype = dtype
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

     def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+        return (torch.randn(1, 512).to(self.dtype),)

     def forward(self, x):
         return self.linear(x)
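A quick sketch of what the dtype parameterization of `M4` enables; the class is reproduced from the hunk above so the snippet runs standalone:

```python
import torch

class M4(torch.nn.Module):
    def __init__(self, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype
        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

    def example_inputs(self):
        return (torch.randn(1, 512).to(self.dtype),)

    def forward(self, x):
        return self.linear(x)

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    m = M4(dtype)
    (x,) = m.example_inputs()
    # Weights and example inputs now follow the requested dtype.
    assert m(x).dtype == dtype
```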
@@ -219,30 +224,41 @@ def test_fake_quantize_per_token(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+        block_size = _get_per_token_block_size(x)
+        (s, zp) = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+        )

         # fake quant op
         out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
         out.sum().backward()

         # compare against PTQ ops
-        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+        out_ptq = quantize_affine(
             x2,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
         )
-        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+        out_ptq = dequantize_affine(
             out_ptq,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
-            torch.float32,
+            output_dtype=torch.float32,
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
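For reference, the PTQ side of that comparison is now just a quantize/dequantize round trip through the affine primitives. A standalone sketch of the round trip, assuming the signatures used in the test above and that the primitives are importable from `torchao.quantization.quant_primitives`:

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

x = torch.randn(100, 256)
block_size = [1, x.size(-1)]  # per-token: one block per row
s, zp = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
q = quantize_affine(x, block_size, s, zp, torch.int8, -128, 127)
dq = dequantize_affine(q, block_size, s, zp, torch.int8, -128, 127, output_dtype=torch.float32)
print(q.dtype, dq.dtype)  # torch.int8 torch.float32
```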
@@ -1004,8 +1020,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
             Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
             """
             # activations
-            (s, zp) = _choose_qparams_per_token_asymmetric(
-                x, torch.float32, torch.int32
+            (s, zp) = choose_qparams_affine(
+                x,
+                mapping_type=MappingType.ASYMMETRIC,
+                block_size=_get_per_token_block_size(x),
+                target_dtype=torch.int8,
+                quant_min=-128,
+                quant_max=127,
+                scale_dtype=torch.float32,
+                zero_point_dtype=torch.int32,
             )
             (qmin, qmax) = _get_qmin_qmax(8)
             x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
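Putting the activation half of this baseline together, the dynamic int8 per-token fake-quant step now reads roughly as below. The wrapper name `fake_quantize_activations_8bit` is made up for illustration; the private helper imports follow the paths shown in the import hunk above, and the quant_primitives path is assumed:

```python
import torch
from torchao.quantization.qat.utils import _fake_quantize_per_token
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.utils import _get_per_token_block_size

def fake_quantize_activations_8bit(x: torch.Tensor) -> torch.Tensor:
    # Dynamic asymmetric per-token qparams, then a differentiable fake-quant.
    s, zp = choose_qparams_affine(
        x,
        mapping_type=MappingType.ASYMMETRIC,
        block_size=_get_per_token_block_size(x),
        target_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.int32,
    )
    qmin, qmax = -(2**7), 2**7 - 1  # same range as _get_qmin_qmax(8)
    return _fake_quantize_per_token(x, s, zp, qmin, qmax)

x = torch.randn(2, 16, 512, requires_grad=True)
x_fq = fake_quantize_activations_8bit(x)
x_fq.sum().backward()  # gradients flow back to x through the fake-quant
```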
@@ -1427,10 +1450,11 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)

+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_fake_quantize_per_token_vs_convert(self):
+    def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
         1. FakeQuantizer with asymmetric per_token config
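The new `@parameterized.expand` decorator turns each listed dtype into its own generated test case and passes it in as the `dtype` argument. A minimal self-contained sketch of the pattern; the class and method names here are invented for illustration:

```python
import unittest

import torch
from parameterized import parameterized

class ExampleDtypeTest(unittest.TestCase):
    # Expands into three tests, one per dtype in the list.
    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
    def test_dtype_roundtrip(self, dtype: torch.dtype):
        x = torch.randn(2, 4).to(dtype)
        self.assertEqual(x.dtype, dtype)

if __name__ == "__main__":
    unittest.main()
```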
@@ -1439,17 +1463,18 @@ def test_fake_quantize_per_token_vs_convert(self):
         from torchao.quantization.utils import per_token_dynamic_quant

         torch.manual_seed(self.SEED)
-        x = torch.randn(1, 235, 2048)
+        x = torch.randn(1, 235, 2048).to(dtype)
         config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         fake_quantizer = FakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_qat_8da4w_prepare_vs_convert(self):
+    def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
         numerics that match exactly over N trials.
@@ -1463,7 +1488,7 @@ def test_qat_8da4w_prepare_vs_convert(self):

         for seed in range(self.SEED, self.SEED + num_trials):
             torch.manual_seed(seed)
-            m = M4()
+            m = M4(dtype)
             torch.manual_seed(seed)
             x = m.example_inputs()