1
+ # Copyright (c) 2024 Intel Corporation
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch .nn as nn
20
+ import transformers
21
+
22
class TritonModuleMixin:
    """Mixin for modules backed by Triton kernels.

    Provides a ``warmup`` hook that backends can override to pre-compile
    their Triton kernels; this base implementation is a no-op.
    """

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        # Intentionally a no-op: subclasses override this to trigger kernel
        # compilation ahead of time for the given model / sequence length.
        pass
26
+
27
+
28
class QuantLinear(nn.Module, TritonModuleMixin):
    """Quantized linear layer whose weights are bit-packed into int32 buffers.

    The layer keeps three buffers produced by :meth:`pack`:

    * ``qweight`` — quantized weights, ``32 // bits`` values per int32 lane,
      shape ``(infeatures // 32 * bits, outfeatures)``;
    * ``qzeros``  — packed per-group zero points,
      shape ``(ceil(infeatures / group_size), outfeatures // 32 * bits)``;
    * ``scales``  — per-group float16 scales,
      shape ``(ceil(infeatures / group_size), outfeatures)``.

    Args:
        bits: Quantization bit width; one of 2, 4 or 8.
        group_size: Number of input channels sharing one scale/zero point;
            ``-1`` means a single group spanning all input features.
        infeatures: Input feature count; must be divisible by 32.
        outfeatures: Output feature count; must be divisible by 32.
        bias: If true, allocate a float16 ``bias`` buffer; otherwise
            ``self.bias`` is ``None``.
        trainable: Flag recorded on the module (packing itself is not
            differentiable).

    Raises:
        NotImplementedError: For an unsupported bit width or feature counts
            not divisible by 32.
    """

    QUANT_TYPE = "triton"

    def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        if infeatures % 32 != 0 or outfeatures % 32 != 0:
            raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        # group_size == -1 means one quantization group covering every input feature.
        self.group_size = group_size if group_size != -1 else infeatures
        self.maxq = 2 ** self.bits - 1  # largest representable quantized value

        # Each int32 lane of qweight holds 32 // bits packed weight values.
        self.register_buffer(
            "qweight",
            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    math.ceil(infeatures / self.group_size),
                    outfeatures // 32 * self.bits,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (math.ceil(infeatures / self.group_size), outfeatures),
                dtype=torch.float16,
            ),
        )

        if bias:
            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

        self.trainable = trainable

    def post_init(self):
        # Hook for subclass/backend initialization after weights are loaded.
        pass

    def pack(self, linear, scales, zeros, g_idx=None):
        """Quantize ``linear``'s weights and fill this module's buffers.

        Args:
            linear: Source layer — ``nn.Linear``, ``nn.Conv2d``, or a
                transformers ``Conv1D`` (the latter two are reshaped /
                transposed into linear form).
            scales: Per-group scales, shape ``(outfeatures, n_groups)``.
            zeros: Per-group zero points as a tensor of the same shape as
                ``scales``, or a scalar applied uniformly to all groups.
            g_idx: Unused; accepted for interface compatibility.
        """
        scales_t = scales.t().contiguous()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()
        self.scales = scales_t.clone().half()

        # Pick the fastest available device for the packing arithmetic.
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # Fix: guard the attribute lookup — torch builds without the XPU
            # backend have no ``torch.xpu`` and raised AttributeError here.
            device = "xpu:0"

        W = linear.weight.data.to(device).clone()
        if isinstance(linear, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(linear, transformers.pytorch_utils.Conv1D):
            W = W.t()

        # Broadcast group-wise scales (and tensor zeros) across the input dim;
        # slicing to W.shape[1] trims the padding from a partial last group.
        repeat_scales = scales.to(device).repeat_interleave(self.group_size, 1)
        if isinstance(zeros, torch.Tensor):
            repeat_zeros = zeros.to(device).repeat_interleave(self.group_size, 1)
            intweight = torch.round(W.to(device) / repeat_scales[:, :W.shape[1]] + repeat_zeros[:, :W.shape[1]]).to(
                torch.int32)
        else:
            repeat_zeros = zeros
            intweight = torch.round(W.to(device) / repeat_scales[:, :W.shape[1]] + repeat_zeros).to(
                torch.int32)

        del repeat_scales
        # Pack 32 // bits quantized values into each int32: shift each value
        # into its bit slot, then sum the slots together.
        intweight = intweight.reshape(-1, intweight.shape[1] // 32 * self.bits, 32 // self.bits)
        order_map = torch.arange(0, 32 // self.bits, device=device) * self.bits
        intweight = intweight << order_map
        intweight = torch.sum(intweight, dim=-1)

        intweight = intweight.t().contiguous().to(torch.int32)
        self.qweight = intweight.to("cpu")

        if isinstance(zeros, torch.Tensor):
            zeros = zeros.t().contiguous()
            # Fix: subtract out of place. ``t().contiguous()`` returns a VIEW of
            # the caller's tensor when the transpose is already contiguous
            # (e.g. the single-group case where zeros is (outfeatures, 1)), so
            # the original in-place ``zeros -= 1`` silently mutated the
            # caller's argument. ``.cpu()`` also makes ``.numpy()`` safe for
            # device tensors.
            zeros = (zeros - 1).cpu().numpy().astype(np.uint32)
            qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
            i = 0
            col = 0
            # Pack 32 // bits zero points per int32 column, mirroring qweight.
            while col < qzeros.shape[1]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1

            qzeros = qzeros.astype(np.int32)
            self.qzeros = torch.from_numpy(qzeros)
        else:
            # Scalar zero point: every packed lane of every column holds the
            # same (zero - 1) value, so build the pattern once and broadcast.
            zeros -= 1
            shape = scales_t.shape
            value = 0
            for j in range(0, (32 // self.bits)):
                value |= zeros << (self.bits * j)
            qzeros = np.ones((shape[0], shape[1] // 32 * self.bits), dtype=np.uint32) * value
            qzeros = qzeros.astype(np.int32)
            self.qzeros = torch.from_numpy(qzeros)
135
+
136
+
137
# Explicit public API: only the quantized linear layer is exported.
__all__ = ["QuantLinear"]