@@ -7,7 +7,8 @@
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import TestCase
+from torch import nn
+from torch.testing._internal import common_utils
 
 from torchao import quantize_
 from torchao.prototype.parq.optim import (
@@ -20,8 +21,17 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.quantization.quant_api import int4_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
+    IntxWeightOnlyConfig,
+    _is_linear,
+)
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
+    ZeroPointDomain,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
 
 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -36,20 +46,20 @@ def split_param_groups(model):
     return params_no_quant, params_quant
 
 
-class M(torch.nn.Module):
+class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False):
         super().__init__()
-        self.embedding = torch.nn.Embedding(10, m)
-        self.linear1 = torch.nn.Linear(m, n, bias=bias)
-        self.linear2 = torch.nn.Linear(n, k, bias=bias)
-        self.relu = torch.nn.ReLU()
-        self.sigmoid = torch.nn.Sigmoid()
+        self.embedding = nn.Embedding(10, m)
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
 
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
-            torch.nn.init.xavier_uniform_(module.weight)
+            nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
+                nn.init.zeros_(module.bias)
 
     def example_inputs(self):
         return torch.randint(1, 10, (1, 256))
@@ -63,7 +73,7 @@ def forward(self, x):
         return x
 
 
-class TestPARQuantization(TestCase):
+class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.model = M(bias=True).to(_DEVICE)
@@ -86,7 +96,7 @@ def test_2bit_unif_quantizer_hard_prox(self):
         optimizer.step()
 
         for child in self.model.children():
-            if isinstance(child, torch.nn.Linear):
+            if isinstance(child, nn.Linear):
                 self.assertEqual(child.weight.unique().numel(), 4)
 
     def test_ternarybit_lsbq_parq_prox(self):
@@ -107,22 +117,47 @@ def test_ternarybit_lsbq_parq_prox(self):
         optimizer.step()
 
         for child in self.model.children():
-            if isinstance(child, torch.nn.Linear):
+            if isinstance(child, nn.Linear):
                 self.assertEqual(child.weight.unique().numel(), 3)
 
 
-class TestUnifTorchaoQuantizer(TestCase):
-    def setUp(self, group_size=32):
+class TestUnifTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
         torch.manual_seed(123)
-        self.model = M(n=1024, k=1024).to(torch.bfloat16).to(_DEVICE)
-        self.group_size = group_size
+        self.group_size = 32
+
+    def compare_quantized_models(
+        self,
+        model: nn.Module,
+        m_ref: nn.Module,
+        quantizer: UnifTorchaoQuantizer,
+        b: int,
+    ):
+        for n, module in model.named_children():
+            if not _is_linear(module):
+                continue
+
+            # simulate grouping from QuantOptimizer.step
+            p = module.weight
+            original_shape = p.shape
+            p = p.view(-1, self.group_size)
+
+            q, Q = quantizer.quantize(p, b=b, dim=-1)
+            q = q.view(original_shape)
+
+            # compare to AffineQuantizedTensor instance
+            ref = getattr(m_ref, n).weight.dequantize()
+            self.assertTrue(q.equal(ref))
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     def test_int4_weight_only(self):
-        self.model.reset_parameters()
-        m_copy = copy.deepcopy(self.model)
-        quantize_(m_copy, int4_weight_only(group_size=self.group_size))
+        model = M(n=1024, k=1024).to(torch.bfloat16).to(_DEVICE)
+        model.reset_parameters()
+        m_ref = copy.deepcopy(model)
+
+        config = Int4WeightOnlyConfig(group_size=self.group_size)
+        quantize_(m_ref, config, device=_DEVICE)
 
         # copied from torchao.quantization.quant_api._int4_weight_only_transform
         b = 4
@@ -137,22 +172,33 @@ def test_int4_weight_only(self):
         self.assertTrue(
             quantizer.get_quant_size(b) == quantizer.quant_max - quantizer.quant_min + 1
         )
+        self.compare_quantized_models(model, m_ref, quantizer, b)
 
-        for n, module in self.model.named_children():
-            if not isinstance(module, torch.nn.Linear):
-                continue
-
-            # simulate grouping from QuantOptimizer.step
-            p = module.weight
-            original_shape = p.shape
-            p = p.view(-1, self.group_size)
-
-            q, Q = quantizer.quantize(p, b=b, dim=-1)
-            q = q.view(original_shape)
-
-            # compare to AffineQuantizedTensor instance
-            ref = getattr(m_copy, n).weight.dequantize()
-            self.assertTrue(q.equal(ref))
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    def test_intx_weight_only(self):
+        model = M(n=512, k=512).to(_DEVICE)
+        model.reset_parameters()
+        m_ref = copy.deepcopy(model)
+
+        config = IntxWeightOnlyConfig(granularity=PerGroup(self.group_size))
+        quantize_(m_ref, config, device=_DEVICE)
+        b = 8
+        q_dtype = torch.int8
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[q_dtype]
+        quantizer = UnifTorchaoQuantizer(
+            symmetric=True,
+            target_dtype=q_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=torch.finfo(torch.float32).eps,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.INT,
+        )
+        self.assertTrue(
+            quantizer.get_quant_size(b) == max(abs(quant_min), quant_max) + 1
+        )
+        self.compare_quantized_models(model, m_ref, quantizer, b)
 
 
 if __name__ == "__main__":
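For context, the round-trip these new tests assert can be sketched outside the unittest harness. The sketch below mirrors `test_intx_weight_only` plus `compare_quantized_models`; it is a minimal sketch, not part of the PR. It assumes torch >= 2.6 and a CUDA device (matching the skip guards above), and it assumes `UnifTorchaoQuantizer` is importable from `torchao.prototype.parq.quant` (the relevant import lines are collapsed in the diff); everything else reuses calls visible above.

# Minimal standalone sketch, assuming the torchao APIs shown in the diff above.
import copy

import torch
from torch import nn

from torchao import quantize_
from torchao.prototype.parq.quant import UnifTorchaoQuantizer  # import path assumed
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    ZeroPointDomain,
)

group_size = 32
model = nn.Sequential(nn.Linear(512, 512, bias=False)).cuda()
m_ref = copy.deepcopy(model)

# Reference path: torchao's own per-group int8 weight-only quantization.
quantize_(m_ref, IntxWeightOnlyConfig(granularity=PerGroup(group_size)))

# PARQ-side quantizer configured exactly as in test_intx_weight_only.
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
quantizer = UnifTorchaoQuantizer(
    symmetric=True,
    target_dtype=torch.int8,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=torch.finfo(torch.float32).eps,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.INT,
)

# Group the weight the way QuantOptimizer.step would, quantize at b=8 bits,
# and check that PARQ reproduces the dequantized AffineQuantizedTensor weights.
p = model[0].weight
q, _ = quantizer.quantize(p.view(-1, group_size), b=8, dim=-1)
assert q.view(p.shape).equal(m_ref[0].weight.dequantize())

Comparing against `dequantize()` rather than raw integer values keeps the check independent of whatever packed layout torchao picks for the reference tensor, which is also why the tests assert exact equality of the dequantized weights.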