-# Convert a ChatGLM model checkpoint to a InferLLM compatible file
+# Convert a ChatGLM model checkpoint to an InferLLM-compatible file
 #
 # Load the model using Torch
 # Iterate over all variables and write them to a binary file.

 # - Name (char[name_length])
 # - Data (int8_t[len])
 #
+#
 # By default, the bigger matrices are converted to 16-bit floats.
 # This can be disabled by adding the "use-f32" CLI argument.
 #
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
-#
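To make the tensor record layout described above concrete, here is a rough sketch of how a single record could be read back. It is illustrative only: it mirrors the `struct.pack` calls in `dump_tensor` further down (n_dims, name length, GGML type, then the dims, the name, and the raw data), it skips the model-parameter and vocabulary header, and it only handles fp32 payloads. `read_tensor_record` is a hypothetical helper, not part of InferLLM.

```python
import struct
import numpy as np

def read_tensor_record(f):
    # one record header: (n_dims, name_length, ggml_type) packed as three int32s
    n_dims, name_len, ggml_type = struct.unpack("iii", f.read(12))
    dims = struct.unpack("i" * n_dims, f.read(4 * n_dims))
    name = f.read(name_len).decode("utf-8")
    if ggml_type != 0:  # 0 == F32 in the GGMLType enum below; other payloads are not decoded here
        raise NotImplementedError("only the fp32 layout is sketched here")
    count = int(np.prod(dims))
    data = np.frombuffer(f.read(4 * count), dtype=np.float32).reshape(dims)
    return name, data
```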
|
 import sys
 import json
 import struct
+from enum import Enum
 import numpy as np
 import torch
 import argparse
 import tempfile
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 from sentencepiece import SentencePieceProcessor

 # parse arguments
 parser = argparse.ArgumentParser(description="Convert a ChatGLM model to an InferLLM-compatible file")
 parser.add_argument("-o", "--outfile", type=str, help="the output file")
 parser.add_argument("-v", "--version", type=int, default=1, help="the ChatGLM model version (1, 2 or 3)")
+parser.add_argument("-q", "--quantization", type=int, default=32, help="quantization bits (32, 16, 8 or 4)")
 args = parser.parse_args()

 # output in the same directory as the model
 model_out_path = args.outfile

-hparams = {
-    "embd_size": 4096,
-    "n_heads": 32,
-    "n_layers": 28,
-    "fc_hidden": 16384,
-}
-dtype = 0
+class GGMLType(Enum):
+    # src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L32
+    F32 = 0
+    F16 = 1
+    QInt4 = 2
+    # QUInt4 = 3
+    QInt8 = 4
+
+alignment_size = 32
+bits = args.quantization
+if bits == 32:
+    dtype = GGMLType.F32
+elif bits == 16:
+    dtype = GGMLType.F16
+    raise NotImplementedError(f"the kernel does not support {bits}-bit output yet")
+elif bits == 8:
+    dtype = GGMLType.QInt8
+elif bits == 4:
+    dtype = GGMLType.QInt4
+else:
+    raise NotImplementedError(f"Unknown quantization bits: {bits}")
+
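With the new flag, a conversion run would look something like `python convert.py -o chatglm2-q8.bin -v 2 -q 8` (the script and output file names here are placeholders, not taken from the repository). Note that `-q 16` is currently rejected by the branch above, so the practical choices are 32, 8 and 4.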
 version = args.version
 if version == 1:
     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float().state_dict()
     auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 elif version == 2:
     model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float().state_dict()
     auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+elif version == 3:
+    model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).float().state_dict()
+    auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

 _, vocab_file = tempfile.mkstemp()
 auto_tokenizer.save_vocabulary(vocab_file)
 tokenizer = SentencePieceProcessor(vocab_file)

+hparams = {
+    "embd_size": config.hidden_size,
+    "n_heads": config.num_attention_heads,
+    "n_layers": config.num_layers,
+    "fc_hidden": 16384,  # ChatGLM-6b default; overwritten from config for v2/v3 below
+}
 hparams.update({"vocab_size": tokenizer.vocab_size()})

-if version == 2:
-    hparams.update({"multi_qeury": 1})
-    hparams.update({"attention_patition": 2})
-    hparams.update({"fc_hidden": 13696})
+if version > 1:
+    hparams.update({"multi_query": 1 if config.multi_query_attention else 0})
+    hparams.update({"attention_partition": config.multi_query_group_num})
+    hparams.update({"fc_hidden": config.ffn_hidden_size})
+

 print(hparams)
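For orientation, with `-v 2` the dictionary printed above should come out roughly as in the sketch below. The numbers are the hard-coded defaults this commit removes in favour of `AutoConfig`, and `vocab_size` is whatever the tokenizer reports, so treat it as an illustration rather than a guarantee.

```python
# approximate contents of hparams for THUDM/chatglm2-6b
{
    "embd_size": 4096,                        # config.hidden_size
    "n_heads": 32,                            # config.num_attention_heads
    "n_layers": 28,                           # config.num_layers
    "fc_hidden": 13696,                       # config.ffn_hidden_size
    "vocab_size": tokenizer.vocab_size(),     # reported by the SentencePiece model
    "multi_query": 1,                         # config.multi_query_attention
    "attention_partition": 2,                 # config.multi_query_group_num
}
```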
|
|
 param_byte += struct.pack("i", hparams["n_layers"])
 param_byte += struct.pack("i", hparams["fc_hidden"])
 param_byte += struct.pack("i", hparams["vocab_size"])
-if version == 2:
+if version > 1:
     param_byte += struct.pack("i", hparams["multi_query"])
     param_byte += struct.pack("i", hparams["attention_partition"])
|
|
 # seek to the end of the file
 fout.seek(0, 2)

-for k, v in model.items():
-    name = k
-    shape = v.shape
+
+
+GGML_QK8_0 = 32
+GGML_QK4_0 = 32
+GGML_QK4_1 = 32
+
+
+GGML_MEM_ALIGN = 16
+
+def float32Toint8(tensor):
+    # reinterpret the raw bytes of a float32 tensor as int8 (the last dim grows by a factor of 4)
+    oriShape = tensor.shape
+    newLastElement = oriShape[-1] * 4
+    newShape = oriShape[:-1] + (newLastElement,)
+    tensor_bytes = tensor.numpy().tobytes()
+    return torch.tensor(np.frombuffer(tensor_bytes, dtype=np.int8)).view(newShape)
+
+def offset(tensor, alignment):
+    # number of bytes occupied by the tensor
+    num_bytes = tensor.element_size() * tensor.nelement()
+    # number of padding bytes needed to reach the next aligned address
+    padding = (alignment - (num_bytes % alignment)) % alignment
+    return num_bytes + padding, padding
+
+def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L51
+    """
+    # equivalent to ggml_quantize_q8_0 in ggml.c
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK8_0 == 0
+    tensor = tensor.view(-1, GGML_QK8_0)
+    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
+    tensor = (tensor / scale).round().clamp(min=-128, max=127).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
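As a sanity check on the block layout produced by `quantize_q8_0` (each 32-value block is stored as its fp32 scale reinterpreted as four int8 bytes via `float32Toint8`, followed by the 32 quantized int8 weights), a round-trip helper might look like the sketch below. `dequantize_q8_0` is hypothetical and only for illustration; it is not part of the converter or of InferLLM.

```python
def dequantize_q8_0(blocks: torch.Tensor) -> torch.Tensor:
    # blocks: int8 tensor of shape (n_blocks, 4 + GGML_QK8_0), as returned by quantize_q8_0
    scale_bytes = blocks[:, :4].contiguous().numpy().tobytes()
    scales = torch.tensor(np.frombuffer(scale_bytes, dtype=np.float32).copy()).view(-1, 1)
    return blocks[:, 4:].float() * scales

# round-trip check: the reconstruction should match the input to within one quantization step
x = torch.randn(4, GGML_QK8_0)
x_hat = dequantize_q8_0(quantize_q8_0(x))
print((x - x_hat).abs().max())  # at most about x.abs().max() / 254
```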
+def quantize_quint4(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L62
+    """
+    # equivalent to ggml_quantize_q4_0 in ggml.c
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK4_0 == 0
+    tensor = tensor.view(-1, GGML_QK4_0)
+    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+    scale = max_values / -8
+    tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
+    # compress two int4 weights into an int8
+    tensor = tensor[:, :16] | (tensor[:, 16:] << 4).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
+
+def quantize_qint4(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L62
+    """
+    # equivalent to ggml_quantize_q4_0 in ggml.c
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK4_0 == 0
+    tensor = tensor.view(-1, GGML_QK4_0)
+    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+    scale = max_values / -8
+    tensor = (tensor / scale).round().clamp(min=-8, max=7).char()
+    # compress two int4 weights into an int8
+    tensor = tensor[:, :16] | (tensor[:, 16:] << 4).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
+def dump_tensor(f, name: str, tensor: torch.Tensor, ggml_type: GGMLType):
+    assert tensor.dtype == torch.float32
+    shape = tensor.shape

     # skip layers.X.attention.inner_attention.rope.freqs
-    if name[-5:] == "freqs":
-        continue
+    if name[-5:] == "freqs" or name[-4:] == "freq":
+        return
+

-    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
     if name.endswith("query_key_value.weight") or name.endswith("attention.query_key_value.bias"):
         if version == 1:
-            v = v.reshape(32, 3, -1).transpose(0, 1).reshape(-1, 4096)
+            tensor = tensor.reshape(32, 3, -1).transpose(0, 1).reshape(-1, 4096)
+    dshape = tensor.shape
+    sname = name.encode('utf-8')

-    data = v.numpy().squeeze()
-    n_dims = len(data.shape)

-    dshape = data.shape
-    sname = name.encode('utf-8')
-    print("write tensor: ", name, " to file :", fout.tell())
-    fout.write(struct.pack("iii", n_dims, len(sname), dtype))
+    if "layernorm" not in name:
+        # tensor data
+        if ggml_type == GGMLType.F32:
+            tensor = tensor.float()
+        elif ggml_type == GGMLType.F16:
+            tensor = tensor.half()
+        elif ggml_type == GGMLType.QInt8:
+            tensor = quantize_q8_0(tensor)
+        elif ggml_type == GGMLType.QInt4:
+            tensor = quantize_qint4(tensor)
+        else:
+            raise NotImplementedError(f"Cannot dump tensor as GGML type {ggml_type}")
+    else:
+        # layernorm weights stay in fp32 regardless of the requested quantization
+        tensor = tensor.float()
+        ggml_type = GGMLType.F32
+
+    n_dims = len(dshape)  # logical (pre-quantization) shape written to the record header
+    print("Processing variable: " + name + " with shape: ", dshape, " and type: ", ggml_type.value)
+    f.write(struct.pack("iii", n_dims, len(sname), ggml_type.value))
     for i in range(n_dims):
-        fout.write(struct.pack("i", dshape[i]))
-    fout.write(sname)
+        f.write(struct.pack("i", dshape[i]))
+    f.write(sname)
+    print("write tensor:", name, "to file:", f.tell())
+
+    tensor.numpy().tofile(f)
+    # pad quantized tensors so the next record starts at an aligned offset
+    if ggml_type == GGMLType.QInt8 or ggml_type == GGMLType.QInt4:
+        length, paddingSize = offset(tensor, alignment_size)
+        if paddingSize > 0:
+            paddingTensor = torch.zeros(paddingSize, dtype=torch.int8)  # padding is counted in bytes
+            paddingTensor.numpy().tofile(f)
+            print("write paddingTensor:", name, "paddingSize:", paddingSize, "to file:", f.tell())
+
+for k, v in model.items():
+    dump_tensor(fout, k, v, dtype)

-    # data
-    data.tofile(fout)

 # I hope this deallocates the memory ..
 model = None
|
|