@@ -122,9 +122,11 @@ def test_ternarybit_lsbq_parq_prox(self):
122
122
123
123
124
124
class TestUnifTorchaoQuantizer (common_utils .TestCase ):
125
- def setUp (self ):
125
+ def __init__ (self , methodName , group_size : int = 32 ):
126
+ super (TestUnifTorchaoQuantizer , self ).__init__ (methodName )
126
127
torch .manual_seed (123 )
127
- self .group_size = 32
128
+ self .group_size = group_size
129
+ self .out_dim = 512
128
130
129
131
def compare_quantized_models (
130
132
self ,
@@ -152,14 +154,13 @@ def compare_quantized_models(
152
154
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "Test only enabled for 2.4+" )
153
155
@unittest .skipIf (_DEVICE == "cpu" , "Need GPU available" )
154
156
def test_int4_weight_only (self ):
155
- model = M (n = 1024 , k = 1024 ).to (torch .bfloat16 ).to (_DEVICE )
157
+ model = M (m = self . out_dim , n = self . out_dim ).to (torch .bfloat16 ).to (_DEVICE )
156
158
model .reset_parameters ()
157
159
m_ref = copy .deepcopy (model )
158
160
159
- config = Int4WeightOnlyConfig (group_size = self .group_size )
160
- quantize_ (m_ref , config , device = _DEVICE )
161
+ quantize_ (m_ref , Int4WeightOnlyConfig (group_size = self .group_size ))
161
162
162
- # copied from torchao.quantization.quant_api._int4_weight_only_transform
163
+ # based off torchao.quantization.quant_api._int4_weight_only_transform
163
164
b = 4
164
165
quantizer = UnifTorchaoQuantizer (
165
166
symmetric = False ,
@@ -169,20 +170,19 @@ def test_int4_weight_only(self):
169
170
eps = 1e-6 ,
170
171
preserve_zero = False ,
171
172
)
172
- self .assertTrue (
173
- quantizer .get_quant_size (b ) == quantizer .quant_max - quantizer .quant_min + 1
174
- )
175
173
self .compare_quantized_models (model , m_ref , quantizer , b )
176
174
177
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "Test only enabled for 2.4 +" )
175
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "Test only enabled for 2.6 +" )
178
176
@unittest .skipIf (_DEVICE == "cpu" , "Need GPU available" )
179
177
def test_intx_weight_only (self ):
180
- model = M (n = 512 , k = 512 ).to (_DEVICE )
178
+ model = M (m = self . out_dim , n = self . out_dim ).to (_DEVICE )
181
179
model .reset_parameters ()
182
180
m_ref = copy .deepcopy (model )
183
181
184
182
config = IntxWeightOnlyConfig (granularity = PerGroup (self .group_size ))
185
- quantize_ (m_ref , config , device = _DEVICE )
183
+ quantize_ (m_ref , config )
184
+
185
+ # based off torchao.quantization.quant_api._intx_weight_only_transform
186
186
b = 8
187
187
q_dtype = torch .int8
188
188
quant_min , quant_max = _DTYPE_TO_QVALUE_BOUNDS [q_dtype ]
@@ -195,11 +195,18 @@ def test_intx_weight_only(self):
195
195
preserve_zero = True ,
196
196
zero_point_domain = ZeroPointDomain .INT ,
197
197
)
198
- self .assertTrue (
199
- quantizer .get_quant_size (b ) == max (abs (quant_min ), quant_max ) + 1
200
- )
201
198
self .compare_quantized_models (model , m_ref , quantizer , b )
202
199
203
200
201
+ def load_tests (loader , tests , pattern ):
202
+ suite = unittest .TestSuite ()
203
+ suite .addTests (loader .loadTestsFromTestCase (TestPARQuantization ))
204
+ suite .addTests (loader .loadTestsFromTestCase (TestUnifTorchaoQuantizer ))
205
+
206
+ group_size = suite ._tests [- 1 ].out_dim # row-wise grouping
207
+ suite .addTest (TestUnifTorchaoQuantizer ("test_intx_weight_only" , group_size ))
208
+ return suite
209
+
210
+
204
211
if __name__ == "__main__" :
205
212
unittest .main ()
0 commit comments