# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+ import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
+     instantiate_parametrized_tests,
+     parametrize,
    run_tests,
)

    _is_fbgemm_genai_gpu_available,
    is_sm_at_least_90,
)
+ from torchao.float8.config import e4m3_dtype
+
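+ # Shared quantization configs referenced by @parametrize below: bf16 or fp8 (e4m3) activations
+ # with preshuffled int4 weights; 2D block_size for linear weights, 3D block_size for bmm weights.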
+ BF16_ACT_CONFIG = FbgemmConfig(
+     input_dtype=torch.bfloat16,
+     weight_dtype=torch.int4,
+     output_dtype=torch.bfloat16,
+     block_size=[1, 128],
+     preshuffle=True,
+ )
+
+ BF16_ACT_BMM_CONFIG = FbgemmConfig(
+     input_dtype=torch.bfloat16,
+     weight_dtype=torch.int4,
+     output_dtype=torch.bfloat16,
+     block_size=[1, 1, 128],
+     preshuffle=True,
+ )
+
+ FP8_ACT_CONFIG = FbgemmConfig(
+     input_dtype=e4m3_dtype,
+     weight_dtype=torch.int4,
+     output_dtype=torch.bfloat16,
+     block_size=[1, 128],
+     preshuffle=True,
+ )
+
+ FP8_ACT_BMM_CONFIG = FbgemmConfig(
+     input_dtype=e4m3_dtype,
+     weight_dtype=torch.int4,
+     output_dtype=torch.bfloat16,
+     block_size=[1, 1, 128],
+     preshuffle=True,
+ )


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
)
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
-         self.config = FbgemmConfig(
-             input_dtype=torch.bfloat16,
-             weight_dtype=torch.int4,
-             output_dtype=torch.bfloat16,
-             block_size=[1, 128],
-             preshuffle=True,
-         )
-         self.bmm_config = FbgemmConfig(
-             input_dtype=torch.bfloat16,
-             weight_dtype=torch.int4,
-             output_dtype=torch.bfloat16,
-             block_size=[1, 1, 128],
-             preshuffle=True,
-         )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-     def test_linear(self):
+     @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+     def test_linear(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
-         quantize_(linear, self.config)
+         quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

-     def test_bmm(self):
+     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+     @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+     def test_bmm(self, bmm_config):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
@@ -74,32 +100,46 @@ def forward(self, x):
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+         quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

-     def test_to_device(self):
+     @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+     def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             quantize_(linear, config)
            linear.to(device)

-     def test_module_path(self):
+     @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+     def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-         quantize_(linear, self.config)
+         quantize_(linear, config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
        )

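+         # check that the quantized weight tensor subclass survives a state_dict save/load round trip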
+         with tempfile.NamedTemporaryFile() as f:
+             torch.save(linear.state_dict(), f)
+             f.seek(0)
+             state_dict = torch.load(f)
+             self.assertEqual(
+                 str(type(state_dict["weight"])),
+                 "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+             )
+
+
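+ # instantiate_parametrized_tests generates a concrete test case for each @parametrize value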
+ instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+

if __name__ == "__main__":
    run_tests()