# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
    run_tests,
)

    is_sm_at_least_90,
)

+BF16_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+BF16_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+FP8_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+
+FP8_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
)
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
@@ -74,32 +103,46 @@ def forward(self, x):
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
        )

+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+

if __name__ == "__main__":
    run_tests()
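
For reference, a minimal usage sketch of the new activation-dtype configs outside the test harness. This only mirrors what the tests above already do; the `torchao.quantization` import path for `FbgemmConfig` and `quantize_` is an assumption, since those import lines are not shown in this diff.

import torch

# Import path assumed; the diff does not show where FbgemmConfig and quantize_ come from.
from torchao.quantization import FbgemmConfig, quantize_

# Mirrors FP8_ACT_CONFIG above: int4 weights, preshuffled layout,
# fp8 activation quantization for the int4 kernel.
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
    activation_dtype_for_int4="fp8",
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)  # weight becomes an Int4GroupwisePreshuffleTensor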