77import unittest
88
99import torch
10- from torch .testing ._internal .common_utils import TestCase
10+ from torch import nn
11+ from torch .testing ._internal import common_utils
1112
1213from torchao import quantize_
1314from torchao .prototype .parq .optim import (
2021 UnifQuantizer ,
2122 UnifTorchaoQuantizer ,
2223)
23- from torchao .quantization .quant_api import int4_weight_only
24- from torchao .utils import TORCH_VERSION_AT_LEAST_2_4
24+ from torchao .quantization .granularity import PerGroup
25+ from torchao .quantization .quant_api import (
26+ Int4WeightOnlyConfig ,
27+ IntxWeightOnlyConfig ,
28+ _is_linear ,
29+ )
30+ from torchao .quantization .quant_primitives import (
31+ _DTYPE_TO_QVALUE_BOUNDS ,
32+ ZeroPointDomain ,
33+ )
34+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_4 , TORCH_VERSION_AT_LEAST_2_6
2535
2636_DEVICE = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
2737
@@ -36,20 +46,20 @@ def split_param_groups(model):
3646 return params_no_quant , params_quant
3747
3848
39- class M (torch . nn .Module ):
49+ class M (nn .Module ):
4050 def __init__ (self , m = 256 , n = 128 , k = 16 , bias = False ):
4151 super ().__init__ ()
42- self .embedding = torch . nn .Embedding (10 , m )
43- self .linear1 = torch . nn .Linear (m , n , bias = bias )
44- self .linear2 = torch . nn .Linear (n , k , bias = bias )
45- self .relu = torch . nn .ReLU ()
46- self .sigmoid = torch . nn .Sigmoid ()
52+ self .embedding = nn .Embedding (10 , m )
53+ self .linear1 = nn .Linear (m , n , bias = bias )
54+ self .linear2 = nn .Linear (n , k , bias = bias )
55+ self .relu = nn .ReLU ()
56+ self .sigmoid = nn .Sigmoid ()
4757
4858 def reset_parameters (self ):
4959 for module in (self .linear1 , self .linear2 ):
50- torch . nn .init .xavier_uniform_ (module .weight )
60+ nn .init .xavier_uniform_ (module .weight )
5161 if module .bias is not None :
52- torch . nn .init .zeros_ (module .bias )
62+ nn .init .zeros_ (module .bias )
5363
5464 def example_inputs (self ):
5565 return torch .randint (1 , 10 , (1 , 256 ))
@@ -63,7 +73,7 @@ def forward(self, x):
6373 return x
6474
6575
66- class TestPARQuantization (TestCase ):
76+ class TestPARQuantization (common_utils . TestCase ):
6777 def setUp (self ):
6878 torch .manual_seed (123 )
6979 self .model = M (bias = True ).to (_DEVICE )
@@ -86,7 +96,7 @@ def test_2bit_unif_quantizer_hard_prox(self):
8696 optimizer .step ()
8797
8898 for child in self .model .children ():
89- if isinstance (child , torch . nn .Linear ):
99+ if isinstance (child , nn .Linear ):
90100 self .assertEqual (child .weight .unique ().numel (), 4 )
91101
92102 def test_ternarybit_lsbq_parq_prox (self ):
@@ -107,22 +117,47 @@ def test_ternarybit_lsbq_parq_prox(self):
107117 optimizer .step ()
108118
109119 for child in self .model .children ():
110- if isinstance (child , torch . nn .Linear ):
120+ if isinstance (child , nn .Linear ):
111121 self .assertEqual (child .weight .unique ().numel (), 3 )
112122
113123
114- class TestUnifTorchaoQuantizer (TestCase ):
115- def setUp (self , group_size = 32 ):
124+ class TestUnifTorchaoQuantizer (common_utils . TestCase ):
125+ def setUp (self ):
116126 torch .manual_seed (123 )
117- self .model = M (n = 1024 , k = 1024 ).to (torch .bfloat16 ).to (_DEVICE )
118- self .group_size = group_size
127+ self .group_size = 32
128+
129+ def compare_quantized_models (
130+ self ,
131+ model : nn .Module ,
132+ m_ref : nn .Module ,
133+ quantizer : UnifTorchaoQuantizer ,
134+ b : int ,
135+ ):
136+ for n , module in model .named_children ():
137+ if not _is_linear (module ):
138+ continue
139+
140+ # simulate grouping from QuantOptimizer.step
141+ p = module .weight
142+ original_shape = p .shape
143+ p = p .view (- 1 , self .group_size )
144+
145+ q , Q = quantizer .quantize (p , b = b , dim = - 1 )
146+ q = q .view (original_shape )
147+
148+ # compare to AffineQuantizedTensor instance
149+ ref = getattr (m_ref , n ).weight .dequantize ()
150+ self .assertTrue (q .equal (ref ))
119151
120152 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "Test only enabled for 2.4+" )
121153 @unittest .skipIf (_DEVICE == "cpu" , "Need GPU available" )
122154 def test_int4_weight_only (self ):
123- self .model .reset_parameters ()
124- m_copy = copy .deepcopy (self .model )
125- quantize_ (m_copy , int4_weight_only (group_size = self .group_size ))
155+ model = M (n = 1024 , k = 1024 ).to (torch .bfloat16 ).to (_DEVICE )
156+ model .reset_parameters ()
157+ m_ref = copy .deepcopy (model )
158+
159+ config = Int4WeightOnlyConfig (group_size = self .group_size )
160+ quantize_ (m_ref , config , device = _DEVICE )
126161
127162 # copied from torchao.quantization.quant_api._int4_weight_only_transform
128163 b = 4
@@ -137,22 +172,33 @@ def test_int4_weight_only(self):
137172 self .assertTrue (
138173 quantizer .get_quant_size (b ) == quantizer .quant_max - quantizer .quant_min + 1
139174 )
175+ self .compare_quantized_models (model , m_ref , quantizer , b )
140176
141- for n , module in self .model .named_children ():
142- if not isinstance (module , torch .nn .Linear ):
143- continue
144-
145- # simulate grouping from QuantOptimizer.step
146- p = module .weight
147- original_shape = p .shape
148- p = p .view (- 1 , self .group_size )
149-
150- q , Q = quantizer .quantize (p , b = b , dim = - 1 )
151- q = q .view (original_shape )
152-
153- # compare to AffineQuantizedTensor instance
154- ref = getattr (m_copy , n ).weight .dequantize ()
155- self .assertTrue (q .equal (ref ))
177+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "Test only enabled for 2.4+" )
178+ @unittest .skipIf (_DEVICE == "cpu" , "Need GPU available" )
179+ def test_intx_weight_only (self ):
180+ model = M (n = 512 , k = 512 ).to (_DEVICE )
181+ model .reset_parameters ()
182+ m_ref = copy .deepcopy (model )
183+
184+ config = IntxWeightOnlyConfig (granularity = PerGroup (self .group_size ))
185+ quantize_ (m_ref , config , device = _DEVICE )
186+ b = 8
187+ q_dtype = torch .int8
188+ quant_min , quant_max = _DTYPE_TO_QVALUE_BOUNDS [q_dtype ]
189+ quantizer = UnifTorchaoQuantizer (
190+ symmetric = True ,
191+ target_dtype = q_dtype ,
192+ quant_min = quant_min ,
193+ quant_max = quant_max ,
194+ eps = torch .finfo (torch .float32 ).eps ,
195+ preserve_zero = True ,
196+ zero_point_domain = ZeroPointDomain .INT ,
197+ )
198+ self .assertTrue (
199+ quantizer .get_quant_size (b ) == max (abs (quant_min ), quant_max ) + 1
200+ )
201+ self .compare_quantized_models (model , m_ref , quantizer , b )
156202
157203
158204if __name__ == "__main__" :
0 commit comments