# LICENSE file in the root directory of this source tree.
import copy
import unittest
+from typing import Optional

import torch
from torch import nn
from torch.testing._internal import common_utils

-from torchao import quantize_
+from torchao.core.config import AOBaseConfig
from torchao.dtypes import Int4CPULayout
from torchao.prototype.parq.optim import (
    ProxHardQuant,
@@ -29,6 +30,7 @@
    IntxWeightOnlyConfig,
    _is_linear,
    int4_weight_only,
+    quantize_,
)
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
@@ -44,18 +46,26 @@ def split_param_groups(model):

    def get_param_groups(model):
        for module in model.children():
-            is_linear = isinstance(module, nn.Linear)
-            for n, p in module.named_parameters(recurse=False):
+            is_linear = _is_linear(module)
+            for n, p in module.named_parameters():
                if is_linear and n == "weight":
                    params_quant.append(p)
                else:
                    params_no_quant.append(p)
-            get_param_groups(module)

    get_param_groups(model)
    return params_quant, params_no_quant


+def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
+    params_quant, params_no_quant = split_param_groups(model)
+    quant_kwargs = {"quant_block_size": group_size} if group_size else {}
+    return [
+        {"params": params_quant, "quant_bits": b, **quant_kwargs},
+        {"params": params_no_quant},
+    ]
+
+
class M(nn.Module):
    def __init__(self, m=256, n=128, k=16, bias=False):
        super().__init__()
@@ -87,46 +97,28 @@ class TestPARQuantization(common_utils.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.model = M(bias=True).to(_DEVICE)
-        self.params_quant, self.params_no_quant = split_param_groups(self.model)

-    def train_loop(self, optimizer, steps=1):
-        for _ in range(steps):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
-            out.sum().backward()
-            optimizer.step()
+    @common_utils.parametrize("b", [0, 1, 2, 4])
+    @common_utils.parametrize("unif_quant", [True, False])
+    @common_utils.parametrize("hard_prox", [True, False])
+    def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
+        if unif_quant and b == 0:
+            self.skipTest("Ternary uniform quantization not yet supported")

-    def test_2bit_unif_quantizer_hard_prox(self):
-        b = 2
        self.model.reset_parameters()
-        param_groups = [
-            {"params": self.params_quant, "quant_bits": b},
-            {"params": self.params_no_quant},
-        ]
+        param_groups = build_param_groups(self.model, b)
        base_optimizer = torch.optim.AdamW(param_groups)
-        quantizer = UnifQuantizer()
-        optimizer = QuantOptimizer(base_optimizer, quantizer, ProxHardQuant())
-        self.train_loop(optimizer)

-        for child in self.model.children():
-            if isinstance(child, nn.Linear):
-                self.assertEqual(
-                    child.weight.unique().numel(), quantizer.get_quant_size(b)
-                )
-
-    def test_ternarybit_lsbq_parq_prox(self):
-        b = 0
-        self.model.reset_parameters()
-        param_groups = [
-            {"params": self.params_quant, "quant_bits": b},
-            {"params": self.params_no_quant},
-        ]
-        base_optimizer = torch.optim.AdamW(param_groups)
-        quantizer = LSBQuantizer()
-        optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxPARQ(anneal_start=0, anneal_end=2)
+        quantizer = UnifQuantizer() if unif_quant else LSBQuantizer()
+        prox_map = (
+            ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
        )
-        self.train_loop(optimizer, steps=3)
+        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
+        for _ in range(3):
+            x = self.model.example_inputs(device=_DEVICE)
+            out = self.model(x)
+            out.sum().backward()
+            optimizer.step()

        for child in self.model.children():
            if isinstance(child, nn.Linear):
@@ -163,6 +155,35 @@ def compare_quantized_models(
            ref = getattr(m_ref, n).weight.dequantize()
            self.assertTrue(q.equal(ref))

+    def compare_parq_convert(
+        self,
+        model: nn.Module,
+        m_ref: nn.Module,
+        optimizer: QuantOptimizer,
+        config: AOBaseConfig,
+    ):
+        # do not update model weights, just quantize
+        optimizer.zero_grad()
+        optimizer.step()
+
+        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
+
+        # equivalent to torchao's convert step
+        model.eval()
+        optimizer.restore_latent_params()
+        quantize_(model, config, filter_fn=optimizer._get_filter_fn(model))
+
+        for n, module in model.named_modules():
+            if not _is_linear(module):
+                continue
+
+            p_orig = getattr(orig_model, n).weight  # PARQ weight
+            p = module.weight.dequantize()  # PARQ weight after quantize_
+            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
+
+            self.assertTrue(p_orig.equal(p_ref))
+            self.assertTrue(p.equal(p_ref))
+
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @common_utils.parametrize("group_size", [32, 256])
    def test_int4_weight_only(self, group_size: int = 32):
@@ -195,7 +216,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
            ),
        )

-        quantizer = UnifTorchaoQuantizer(symmetric=True)
+        quantizer = UnifTorchaoQuantizer()
        self.compare_quantized_models(model, m_ref, quantizer, b, group_size)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@@ -211,40 +232,36 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
        quantize_(m_ref, config)

        b = 4
-        params_quant, params_no_quant = split_param_groups(model)
-        param_groups = [
-            {"params": params_quant, "quant_bits": b, "quant_block_size": group_size},
-            {"params": params_no_quant},
-        ]
-        base_optimizer = torch.optim.AdamW(param_groups)
-
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
        optimizer = QuantOptimizer(
            base_optimizer,
            Int4UnifTorchaoQuantizer(),
            ProxHardQuant(),
            quant_per_channel=True,
        )
+        self.compare_parq_convert(model, m_ref, optimizer, config)

-        # do not update model weights, just quantize
-        optimizer.zero_grad()
-        optimizer.step()
-
-        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
-
-        # equivalent to torchao's convert step
-        model.eval()
-        optimizer.torchao_quantize_(model, config)
-
-        for n, module in model.named_modules():
-            if not _is_linear(module):
-                continue
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @common_utils.parametrize("b", [2, 3, 4, 8])
+    def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()

-            p_orig = getattr(orig_model, n).weight  # PARQ weight
-            p = module.weight.dequantize()  # PARQ weight after quantize_
-            p_ref = getattr(m_ref, n).weight.dequantize()  # natively quantize_
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        config = IntxWeightOnlyConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(group_size)
+        )
+        quantize_(m_ref, config)

-            self.assertTrue(p_orig.equal(p_ref))
-            self.assertTrue(p.equal(p_ref))
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer,
+            UnifTorchaoQuantizer(),
+            ProxHardQuant(),
+            quant_per_channel=True,
+        )
+        self.compare_parq_convert(model, m_ref, optimizer, config)


common_utils.instantiate_parametrized_tests(TestPARQuantization)