Commit d5c2ed9

Merge pull request #81 from xyc0123456789/main
feat(chatglm3-6B): support chatglm3-6B
2 parents 405d866 + a1f1154 commit d5c2ed9

6 files changed: +305 -33

application/chatglm/chatglm.cpp (+3)

@@ -145,6 +145,9 @@ int main(int argc, char** argv) {
     if (params.version == 2) {
         model_name = "chatglm2";
         etoken = 2;
+    } else if (params.version == 3) {
+        model_name = "chatglm3";
+        etoken = 2;
     }
 
     std::shared_ptr<inferllm::Model> model =

application/chatglm/convert.py (+165 -32)
@@ -1,4 +1,4 @@
-# Convert a ChatGLM model checkpoint to a InferLLM compatible file
+# Convert a chatglm model checkpoint to a InferLLM compatible file
 #
 # Load the model using Torch
 # Iterate over all variables and write them to a binary file.
@@ -27,57 +27,86 @@
 # - Name (char[name_length])
 # - Data (int8_t[len])
 #
+#
 # By default, the bigger matrices are converted to 16-bit floats.
 # This can be disabled by adding the "use-f32" CLI argument.
 #
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
-#
 
 import sys
 import json
 import struct
+from enum import Enum
 import numpy as np
 import torch
 import argparse
 import tempfile
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 from sentencepiece import SentencePieceProcessor
 
 # parse arguments
 parser = argparse.ArgumentParser(description="Convert a ChatGLM model to a InferLLM compatible fp16 data type file")
 parser.add_argument("-o", "--outfile", type=str, help="the output file")
 parser.add_argument("-v", "--version", type=int, default=1, help="the chatglm mode version")
+parser.add_argument("-q", "--quantization", type=int, default=32, help="quantization bits")
 args = parser.parse_args()
 
 # output in the same directory as the model
 model_out_path = args.outfile
 
-hparams = {
-    "embd_size": 4096,
-    "n_heads": 32,
-    "n_layers": 28,
-    "fc_hidden": 16384,
-}
-dtype = 0
+class GGMLType(Enum):
+    # src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L32
+    F32 = 0
+    F16 = 1
+    QInt4 = 2
+    # QUInt4 = 3
+    QInt8 = 4
+
+alignment_size = 32
+bits = args.quantization
+if bits == 32:
+    dtype = GGMLType.F32
+elif bits == 16:
+    dtype = GGMLType.F16
+    raise NotImplementedError(f"kernel not suport bits: {bits}")
+elif bits == 8:
+    dtype = GGMLType.QInt8
+elif bits == 4:
+    dtype = GGMLType.QInt4
+else:
+    raise NotImplementedError(f"Unknown quantization bits: {bits}")
+
 version = args.version
 if version == 1:
     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float().state_dict()
     auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 elif version == 2:
     model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float().state_dict()
     auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+elif version == 3:
+    model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).float().state_dict()
+    auto_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
+    config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
 
 _, vocab_file = tempfile.mkstemp()
 auto_tokenizer.save_vocabulary(vocab_file)
 tokenizer = SentencePieceProcessor(vocab_file)
 
+hparams = {
+    "embd_size": config.hidden_size,
+    "n_heads": config.num_attention_heads,
+    "n_layers": config.num_layers,
+}
 hparams.update({"vocab_size": tokenizer.vocab_size()})
 
-if version == 2:
-    hparams.update({"multi_qeury": 1})
-    hparams.update({"attention_patition": 2})
-    hparams.update({"fc_hidden": 13696})
+if version > 1:
+    hparams.update({"multi_qeury": 1 if config.multi_query_attention else 0})
+    hparams.update({"attention_patition": config.multi_query_group_num})
+    hparams.update({"fc_hidden": config.ffn_hidden_size})
+
 
 print(hparams)
 
@@ -91,7 +120,7 @@
 param_byte +=struct.pack("i", hparams["n_layers"])
 param_byte +=struct.pack("i", hparams["fc_hidden"])
 param_byte +=struct.pack("i", hparams["vocab_size"])
-if version == 2:
+if version > 1:
     param_byte +=struct.pack("i", hparams["multi_qeury"])
     param_byte +=struct.pack("i", hparams["attention_patition"])
 
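As a quick illustration of the hunk above (a minimal sketch, not the converter's full header logic; the hparams values below are hypothetical): every header field is written as one 4-byte native-endian int via struct.pack("i", ...), and for version > 1 the multi_qeury and attention_patition fields are appended as two extra ints, so a chatglm2/chatglm3 header carries eight more bytes than a chatglm-6b header at this point in the file.

import struct

# hypothetical values, purely for illustration
hparams = {"n_layers": 28, "fc_hidden": 13696, "vocab_size": 65024,
           "multi_qeury": 1, "attention_patition": 2}
version = 3

param_byte = b""
param_byte += struct.pack("i", hparams["n_layers"])
param_byte += struct.pack("i", hparams["fc_hidden"])
param_byte += struct.pack("i", hparams["vocab_size"])
if version > 1:  # chatglm2 and chatglm3 both carry the two extra fields
    param_byte += struct.pack("i", hparams["multi_qeury"])
    param_byte += struct.pack("i", hparams["attention_patition"])

print(len(param_byte))  # 20 bytes here for version 2/3, 12 for version 1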
@@ -150,32 +179,136 @@
 # seek to the end of the file
 fout.seek(0, 2)
 
-for k, v in model.items():
-    name = k
-    shape = v.shape
+
+
+GGML_QK8_0 = 32
+GGML_QK4_0 = 32
+GGML_QK4_1 = 32
+
+
+GGML_MEM_ALIGN = 16
+
+def float32Toint8(tensor):
+    oriShape = tensor.shape
+    newLastElement = oriShape[-1] * 4
+    newShape = oriShape[:-1] + (newLastElement,)
+    tensor_bytes = tensor.numpy().tobytes()
+    return torch.tensor(np.frombuffer(tensor_bytes, dtype=np.int8)).view(newShape)
+
+def offset(tensor, alignment):
+    # compute the number of bytes occupied by the tensor
+    num_bytes = tensor.element_size() * tensor.nelement()
+    # compute the number of padding bytes needed
+    padding = (alignment - (num_bytes % alignment)) % alignment
+    return num_bytes + padding, padding
+
+def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L51
+    """
+    # equivalent to ggml_quantize_q8_0 in ggml.c
+
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK8_0 == 0
+    tensor = tensor.view(-1, GGML_QK8_0)
+    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
+    tensor = (tensor / scale).round().clamp(min=-128, max=127).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
+def quantize_quint4(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L62
+    """
+    # equivalent to ggml_quantize_q4_0 in ggml.c
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK4_0 == 0
+    tensor = tensor.view(-1, GGML_QK4_0)
+    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+    scale = max_values / -8
+    tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
+    # compress two int4 weights into an int8
+    tensor = tensor[:, :16] | (tensor[:, 16:] << 4).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
+
+
+def quantize_qint4(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    src: https://github.com/li-plus/chatglm.cpp/blob/04910ce72a5d22087ec6e404dbefd73c1ccf2700/chatglm_cpp/convert.py#L62
+    """
+    # equivalent to ggml_quantize_q4_0 in ggml.c
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    assert tensor.shape[1] % GGML_QK4_0 == 0
+    tensor = tensor.view(-1, GGML_QK4_0)
+    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+    scale = max_values / -8
+    tensor = (tensor / scale).round().clamp(min=-8, max=7).char()
+    # compress two int4 weights into an int8
+    tensor = tensor[:, :16] | (tensor[:, 16:] << 4).type(torch.int8)
+    # add scale into each block
+    tensor = torch.cat((float32Toint8(scale.float()), tensor), dim=-1)
+    return tensor
+
+def dump_tensor(f, name: str, tensor: torch.Tensor, ggml_type: GGMLType):
+    assert tensor.dtype == torch.float32
+    shape = tensor.shape
 
     # skip layers.X.attention.inner_attention.rope.freqs
-    if name[-5:] == "freqs":
-        continue
+    if name[-5:] == "freqs" or name[-4:] == "freq":
+        return
+
 
-    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
     if name.endswith("query_key_value.weight") or name.endswith("attention.query_key_value.bias"):
         if version == 1:
-            v = v.reshape(32, 3, -1).transpose(0, 1).reshape(-1, 4096)
+            tensor = tensor.reshape(32, 3, -1).transpose(0, 1).reshape(-1, 4096)
+    dshape = tensor.shape
+    sname = name.encode('utf-8')
 
-    data = v.numpy().squeeze()
-    n_dims = len(data.shape)
 
-    dshape = data.shape
-    sname = name.encode('utf-8')
-    print("write tensor: ", name, " to file :", fout.tell())
-    fout.write(struct.pack("iii", n_dims, len(sname), dtype))
+    if "layernorm" not in name:
+        # tensor data
+        if ggml_type == GGMLType.F32:
+            tensor = tensor.float()
+        elif ggml_type == GGMLType.F16:
+            tensor = tensor.half()
+        elif ggml_type == GGMLType.QInt8:
+            tensor = quantize_q8_0(tensor)
+        elif ggml_type == GGMLType.QInt4:
+            tensor = quantize_qint4(tensor)
+        else:
+            raise NotImplementedError(f"Cannot dump tensor of dtype {tensor.dtype}")
+    else:
+        tensor = tensor.float()
+        ggml_type = GGMLType.F32
+
+    n_dims = len(shape)
+    print("Processing variable: " + name + " with shape: ", shape, " and type: ", ggml_type.value)
+    f.write(struct.pack("iii", n_dims, len(sname), ggml_type.value))
     for i in range(n_dims):
-        fout.write(struct.pack("i", dshape[i]))
-    fout.write(sname)
+        f.write(struct.pack("i", dshape[i]))
+    f.write(sname)
+    print("write tensor: ", name, " to file :", f.tell())
+
+    tensor.numpy().tofile(f)
+    # align address
+    if ggml_type == GGMLType.QInt8 or ggml_type == GGMLType.QInt4:
+        length, paddingSize = offset(tensor, alignment_size)
+        if paddingSize > 0:
+            paddingTensor = torch.zeros(paddingSize)
+            paddingTensor.numpy().tofile(f)
+            print("write paddingTensor: ", name, "paddingSize:", paddingSize, " to file :", f.tell())
+
+for k, v in model.items():
+    dump_tensor(fout, k, v, dtype)
 
-    # data
-    data.tofile(fout)
 
 # I hope this deallocates the memory ..
 model = None
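Going by the argparse flags added above, an int8 conversion of chatglm3-6B would be invoked roughly as: python convert.py -o chatglm3-6b-q8.bin -v 3 -q 8 (the output filename is illustrative). The snippet below is a self-contained sketch that only mirrors quantize_q8_0 and offset from this diff, not the converter itself; it shows the resulting Q8_0 block layout: each block of 32 float32 weights becomes one float32 scale stored as 4 int8 bytes followed by 32 int8 weights (36 bytes per block), and quantized tensors are then zero-padded up to the 32-byte alignment_size before the next tensor is written.

import numpy as np
import torch

GGML_QK8_0 = 32   # weights per quantization block
ALIGNMENT = 32    # corresponds to alignment_size in the diff above

def quantize_q8_0_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # same scheme as quantize_q8_0 above: per-block absmax scale, then int8 weights
    tensor = tensor.view(-1, GGML_QK8_0)
    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    q = (tensor / scale).round().clamp(min=-128, max=127).to(torch.int8)
    # reinterpret the float32 scale as 4 int8 bytes and prepend it to each block
    scale_bytes = torch.from_numpy(
        np.frombuffer(scale.float().numpy().tobytes(), dtype=np.int8).copy()
    ).view(-1, 4)
    return torch.cat((scale_bytes, q), dim=-1)

def padding_bytes(num_bytes: int, alignment: int) -> int:
    # same arithmetic as offset() above: zero bytes appended to reach the alignment
    return (alignment - (num_bytes % alignment)) % alignment

weights = torch.randn(2, 64)                        # 4 blocks of 32 weights
packed = quantize_q8_0_sketch(weights)
print(packed.shape)                                 # torch.Size([4, 36]): 4 scale bytes + 32 weights
print(padding_bytes(packed.nelement(), ALIGNMENT))  # 16 zero bytes of padding (144 -> 160)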

src/core/graph.cpp (+1 -1)

@@ -320,7 +320,7 @@ void Graph::load(
             "Error weight is not found when loading.");
     auto weight = m_weights_map[alias_name];
     if (weight->length() != nr_number) {
-        INFER_LOG("weight %s is not match.\n", alias_name.c_str());
+        INFER_LOG("weight %s %zu is not match.\n", alias_name.c_str(), weight->length());
     }
     INFER_ASSERT(
             weight->length() == nr_number, "Error length of weight is mismatch.");

src/graph/chatGLM.h (+13)

@@ -41,6 +41,19 @@ class ChatGLMGraph : public Graph {
 class ChatGLMGraph2 : public Graph {
     using Graph::Graph;
 
+public:
+    void set_weights_alias() override;
+    void construct_llm() override;
+    void load_param(
+            std::shared_ptr<InputFile> fin, LlmParams& param,
+            std::shared_ptr<Vocab> vocab) override;
+
+    void post_tokenize(std::vector<Vocab::Id>& input) override;
+};
+
+class ChatGLMGraph3 : public Graph {
+    using Graph::Graph;
+
 public:
     void set_weights_alias() override;
     void construct_llm() override;
