11
11
from torch .testing ._internal import common_utils
12
12
13
13
from torchao import quantize_
14
+ from torchao .dtypes import Int4CPULayout , TensorCoreTiledLayout
14
15
from torchao .prototype .parq .optim import (
15
16
ProxHardQuant ,
16
17
ProxPARQ ,
24
25
from torchao .prototype .parq .quant .uniform_torchao import _BIT_WIDTH_TO_DTYPE
25
26
from torchao .quantization .granularity import PerGroup
26
27
from torchao .quantization .quant_api import (
27
- Int4WeightOnlyConfig ,
28
+ LAYOUT_TO_ZERO_POINT_DOMAIN ,
28
29
IntxWeightOnlyConfig ,
29
30
_is_linear ,
31
+ int4_weight_only ,
30
32
)
31
33
from torchao .quantization .quant_primitives import ZeroPointDomain
32
34
from torchao .utils import TORCH_VERSION_AT_LEAST_2_4 , TORCH_VERSION_AT_LEAST_2_6
@@ -66,8 +68,8 @@ def reset_parameters(self):
66
68
if module .bias is not None :
67
69
nn .init .zeros_ (module .bias )
68
70
69
- def example_inputs (self ):
70
- return torch .randint (1 , 10 , (1 , 256 ))
71
+ def example_inputs (self , device = None ):
72
+ return torch .randint (1 , 10 , (1 , 256 ), device = device )
71
73
72
74
def forward (self , x ):
73
75
x = self .embedding (x )
@@ -78,20 +80,19 @@ def forward(self, x):
78
80
return x
79
81
80
82
81
- def train_loop (model , optimizer , update = True , steps = 1 ):
82
- for _ in range (steps ):
83
- x = model .example_inputs ().to (_DEVICE )
84
- out = model (x )
85
- out .sum ().backward ()
86
- optimizer .step ()
87
-
88
-
89
83
class TestPARQuantization (common_utils .TestCase ):
90
84
def setUp (self ):
91
85
torch .manual_seed (123 )
92
86
self .model = M (bias = True ).to (_DEVICE )
93
87
self .params_quant , self .params_no_quant = split_param_groups (self .model )
94
88
89
+ def train_loop (self , optimizer , steps = 1 ):
90
+ for _ in range (steps ):
91
+ x = self .model .example_inputs (device = _DEVICE )
92
+ out = self .model (x )
93
+ out .sum ().backward ()
94
+ optimizer .step ()
95
+
95
96
def test_2bit_unif_quantizer_hard_prox (self ):
96
97
b = 2
97
98
self .model .reset_parameters ()
@@ -102,7 +103,7 @@ def test_2bit_unif_quantizer_hard_prox(self):
102
103
base_optimizer = torch .optim .AdamW (param_groups )
103
104
quantizer = UnifQuantizer ()
104
105
optimizer = QuantOptimizer (base_optimizer , quantizer , ProxHardQuant ())
105
- train_loop ( self .model , optimizer )
106
+ self .train_loop ( optimizer )
106
107
107
108
for child in self .model .children ():
108
109
if isinstance (child , nn .Linear ):
@@ -122,7 +123,7 @@ def test_ternarybit_lsbq_parq_prox(self):
122
123
optimizer = QuantOptimizer (
123
124
base_optimizer , quantizer , ProxPARQ (anneal_start = 0 , anneal_end = 2 )
124
125
)
125
- train_loop ( self .model , optimizer , steps = 3 )
126
+ self .train_loop ( optimizer , steps = 3 )
126
127
127
128
for child in self .model .children ():
128
129
if isinstance (child , nn .Linear ):
@@ -136,7 +137,7 @@ def setUp(self):
136
137
torch .manual_seed (123 )
137
138
138
139
@staticmethod
139
- def int4_torchao_quantizer (b : int = 4 , config = None ):
140
+ def int4_torchao_quantizer (config , b : int = 4 ):
140
141
# based off torchao.quantization.quant_api._int4_weight_only_transform
141
142
return UnifTorchaoQuantizer (
142
143
symmetric = False ,
@@ -145,6 +146,7 @@ def int4_torchao_quantizer(b: int = 4, config=None):
145
146
quant_max = 2 ** b - 1 ,
146
147
eps = 1e-6 ,
147
148
preserve_zero = False ,
149
+ zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN [type (config .layout )][0 ],
148
150
config = config ,
149
151
)
150
152
@@ -172,28 +174,31 @@ def compare_quantized_models(
172
174
ref = getattr (m_ref , n ).weight .dequantize ()
173
175
self .assertTrue (q .equal (ref ))
174
176
175
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "Test only enabled for 2.4+" )
176
- @unittest .skipIf (_DEVICE == "cpu" , "Need GPU available" )
177
177
@common_utils .parametrize ("group_size" , [32 , 256 ])
178
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "Test only enabled for 2.4+" )
178
179
def test_int4_weight_only (self , group_size : int = 32 ):
179
180
model = M (m = 512 , n = 512 ).to (torch .bfloat16 ).to (_DEVICE )
180
181
model .reset_parameters ()
181
182
182
- m_ref = copy .deepcopy (model )
183
- quantize_ (m_ref , Int4WeightOnlyConfig (group_size ))
183
+ m_ref = copy .deepcopy (model ).eval ()
184
+ config = int4_weight_only (
185
+ group_size = group_size ,
186
+ layout = Int4CPULayout () if _DEVICE == "cpu" else TensorCoreTiledLayout (8 ),
187
+ )
188
+ quantize_ (m_ref , config )
184
189
185
190
b = 4
186
- quantizer = self .int4_torchao_quantizer ()
191
+ quantizer = self .int4_torchao_quantizer (config )
187
192
self .compare_quantized_models (model , m_ref , quantizer , b , group_size )
188
193
189
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "Test only enabled for 2.6+" )
190
194
@common_utils .parametrize ("b" , [2 , 3 , 4 , 8 ])
191
195
@common_utils .parametrize ("group_size" , [32 , 512 ])
196
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "Test only enabled for 2.6+" )
192
197
def test_intx_weight_only (self , b : int = 2 , group_size : int = 32 ):
193
198
model = M (m = 512 , n = 512 ).to (_DEVICE )
194
199
model .reset_parameters ()
195
200
196
- m_ref = copy .deepcopy (model )
201
+ m_ref = copy .deepcopy (model ). eval ()
197
202
quantize_ (
198
203
m_ref ,
199
204
IntxWeightOnlyConfig (
@@ -214,8 +219,11 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
214
219
model = M (m = 512 , n = 512 ).to (torch .bfloat16 ).to (_DEVICE )
215
220
model .reset_parameters ()
216
221
217
- m_ref = copy .deepcopy (model )
218
- config = Int4WeightOnlyConfig (group_size )
222
+ m_ref = copy .deepcopy (model ).eval ()
223
+ config = int4_weight_only (
224
+ group_size = group_size ,
225
+ layout = TensorCoreTiledLayout (8 ),
226
+ )
219
227
quantize_ (m_ref , config )
220
228
221
229
b = 4
@@ -226,7 +234,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
226
234
]
227
235
base_optimizer = torch .optim .AdamW (param_groups )
228
236
229
- quantizer = self .int4_torchao_quantizer (config = config )
237
+ quantizer = self .int4_torchao_quantizer (config )
230
238
optimizer = QuantOptimizer (
231
239
base_optimizer , quantizer , ProxHardQuant (), quant_per_channel = True
232
240
)
@@ -238,6 +246,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
238
246
orig_model = copy .deepcopy (model ) # save copy of PARQ quantized model
239
247
240
248
# equivalent to torchao's convert step
249
+ model .eval ()
241
250
with torch .no_grad ():
242
251
optimizer .restore_latent_params ()
243
252
quantize_ (model , quantizer .config )
0 commit comments