+ import gc
  import base64
  import hashlib
  import json
+ import asyncio
+ import threading
  import logging
  import os
  import random
+ import tarfile
+ import shutil

  import transformers
  from datasets import load_dataset
  from httpx import AsyncClient, Response

  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback

- from ai_worker.jsonlines import load_jsonlines
+ from gguf_loader.convert import main as gguf_main

MAX_CONTEXT = 300000

log = logging.getLogger(__name__)

+ def gzip(folder):
+     """tar gz the folder to 'folder.tar.gz' and return the archive path"""
+     base_folder_name = os.path.basename(folder)
+     with tarfile.open(f"{folder}.tar.gz", 'w:gz') as archive:
+         archive.add(folder, arcname=base_folder_name)
+     return f"{folder}.tar.gz"

class FineTuner:
    def __init__(self, conf):
        self.conf = conf
        os.makedirs(self.conf.tmp_dir, exist_ok=True)

-     def temp_file(self, name):
-         return os.path.join(self.conf.tmp_dir, name)
+     def temp_file(self, name, wipe=False):
+         ret = os.path.join(self.conf.tmp_dir, name)
+         if wipe:
+             shutil.rmtree(ret, ignore_errors=True)
+         return ret

    def massage_line(self, ln, job):
+         # toss the system role for now, for some reason it didn't work
+         # todo: check for role support in template
+         if "mistral" in job["model"].lower():
+             j = json.loads(ln)
+             j["messages"] = [m for m in j["messages"] if m["role"] != "system"]
+             ln = json.dumps(j) + "\n"
        return ln

    def massage_fine_tune(self, file, job):
        cnt = 0
+         tc = 0
+         ec = 0
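+         # tc/ec count how many lines have been written to the train/eval files so far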
        training_split_pct = job.get("hyperparameters", {}).get("training_split", 0.8)

        train_file = file + ".train"
@@ -41,42 +63,70 @@ def massage_fine_tune(self, file, job):
            with open(eval_file, "w") as ef:
                with open(file, "r") as inp:
                    ln = inp.readline(MAX_CONTEXT)
-                     ln = self.massage_line(ln, job)
                    while ln:
+                         ln = self.massage_line(ln, job)
                        cnt += 1
-                         if random.random() > training_split_pct:
+                         if ec and (random.random() > training_split_pct or tc <= ec):
+                             tc += 1
                            tf.write(ln)
                        else:
+                             ec += 1
                            ef.write(ln)
                        ln = inp.readline(MAX_CONTEXT)
-                         ln = self.massage_line(ln, job)
        return train_file, eval_file

    async def fine_tune(self, job):
        log.info("fine tuning: %s", job)

-         yield {"status": "downloading_data"}
+         yield {"status": "download_data"}

-         base_model = job["model"]
        training_url = job["training_file"]
        training_file = await self.download_file(training_url)
-
+         job["training_file"] = training_file
+
+         q = asyncio.Queue()
+
+         loop = asyncio.get_running_loop()
+
+         t = threading.Thread(target=lambda: self._fine_tune(job, lambda res: loop.call_soon_threadsafe(q.put_nowait, res)), daemon=True)
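+         # training runs in this daemon thread; each progress dict is posted back to
+         # the event loop via call_soon_threadsafe and drained from the queue below,
+         # with None used as the completion sentinel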
+
+         t.start()
+         while True:
+             res = await q.get()
+             if res is None:
+                 break
+             yield res
+         log.info("DONE")
+         t.join()
+
+     def _fine_tune(self, job, cb):
+         try:
+             self._unsafe_fine_tune(job, cb)
+         except Exception as ex:
+             log.exception("error in fine tune")
+             cb({"status": "error", "detail": repr(ex)})
+         finally:
+             cb(None)
+
+     def _unsafe_fine_tune(self, job, cb):
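+         # cb is invoked from this worker thread with status dicts; the async
+         # generator above forwards them to the caller as they arrive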
+         training_file = job["training_file"]
        train_file, eval_file = self.massage_fine_tune(training_file, job)

-         train_dataset = load_jsonlines(open(train_file))
-         eval_dataset = load_jsonlines(open(eval_file))
-
-         # todo: use user's model request
+         base_model = job["model"]
+         datasets = load_dataset("json", data_files={"train": train_file, "eval": eval_file})
+         train_dataset = datasets["train"]
+         eval_dataset = datasets["eval"]

-         base_model_id = "mistralai/Mistral-7B-v0.1"
+         base_model_id = base_model.split(":")[0]

        # todo: use hyperparams and Q_ filter, if present, for this

        hp = job.get("hyperparameters", {})

        args = {}

-         yield {"status": "loading_model"}
+         log.info("load model")
+         cb({"status": "load_model"})

        args.update(dict(
            load_in_4bit=True,
@@ -89,25 +139,23 @@ async def fine_tune(self, job):

        # todo: ideally we use llama cpp, but the cuda support for finetune isn't there

-         model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
-
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_id,
            padding_side="left",
            add_eos_token=True,
            add_bos_token=True,
        )

-         train_dataset = tokenizer.apply_chat_template(train_dataset)
-         eval_dataset = tokenizer.apply_chat_template(eval_dataset)
-
+         # sadly, does not take generators, just loads everything in ram
        tokenizer.pad_token = tokenizer.eos_token
-
-         max_length = 512
-
+         # todo: derive from model params
+         max_length = 4096
        def generate_and_tokenize_prompt(prompt):
+             # all input is openai formatted, and we clean it up above if needed
+             pr = prompt["messages"]
+             tmpl = tokenizer.apply_chat_template(pr, tokenize=False)
            result = tokenizer(
-                 prompt,
+                 tmpl,
                truncation=True,
                max_length=max_length,
                padding="max_length",
@@ -118,7 +166,9 @@ def generate_and_tokenize_prompt(prompt):
        tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
        tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)

-         from peft import prepare_model_for_kbit_training
+         model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto", resume_download=True)
+
+         from peft import prepare_model_for_kbit_training, PeftModel

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
@@ -127,7 +177,7 @@ def generate_and_tokenize_prompt(prompt):

        config = LoraConfig(
            r=32,
-             lora_alpha=64,
+             lora_alpha=hp.get("lora_alpha", 64),
            target_modules=[
                "q_proj",
                "k_proj",
@@ -139,7 +189,7 @@ def generate_and_tokenize_prompt(prompt):
                "lm_head",
            ],
            bias="none",
-             lora_dropout=0.05,  # Conventional
+             lora_dropout=hp.get("lora_dropout", 0.05),  # Conventional
            task_type="CAUSAL_LM",
        )

@@ -162,41 +212,114 @@ def generate_and_tokenize_prompt(prompt):
        model.model_parallel = True

        project = "journal-finetune"
-         base_model_name = "mistral"
-         run_name = base_model_name + "-" + project
+         base_model_name = base_model_id.split("/")[-1]
+         run_name = base_model_name + "-" + project + "-" + os.urandom(16).hex()
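+         # the random suffix keeps concurrent runs from writing to the same output dir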
        output_dir = "./" + run_name

        tokenizer.pad_token = tokenizer.eos_token

+         class EarlyStoppingCallback(TrainerCallback):
+             def on_log(self, args, state, control, logs=None, **kwargs):
+                 cb({"status": "log", "logs": logs})
+                 eval_loss = logs.get("eval_loss", None)
+                 if eval_loss is not None and eval_loss <= hp.get("stop_eval_loss", 0.05):
+                     print("Early stopping criterion reached!")
+                     control.should_training_stop = True
+
+             def on_save(self, args, state, control, **kwargs):
+                 checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
+                 log.info(f"checkpoint {checkpoint_dir}")
+                 cb({"status": "checkpoint"})
+
+
        trainer = transformers.Trainer(
            model=model,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_val_dataset,
+             callbacks=[EarlyStoppingCallback()],
            args=transformers.TrainingArguments(
                output_dir=output_dir,
                warmup_steps=1,
-                 per_device_train_batch_size=2,
-                 gradient_accumulation_steps=1,
-                 max_steps=500,
-                 learning_rate=2.5e-5,  # Want a small lr for finetuning
+                 per_device_train_batch_size=hp.get("batch_size", 4),
+                 gradient_accumulation_steps=hp.get("accumulation_steps", 4),
+                 max_steps=hp.get("max_steps", -1),
+                 num_train_epochs=hp.get("n_epochs", 3),  # use openai terminology here
+                 learning_rate=hp.get("learning_rate_multiplier", 2.5e-5),  # Want a small lr for finetuning
                bf16=True,
                optim="paged_adamw_8bit",
                logging_steps=25,  # When to start reporting loss
                logging_dir="./logs",  # Directory for storing logs
                save_strategy="steps",  # Save the model checkpoint every logging step
                save_steps=25,  # Save checkpoints
+                 save_total_limit=5,  # Keep at most 5 checkpoints
+                 load_best_model_at_end=True,
                evaluation_strategy="steps",  # Evaluate the model every logging step
-                 eval_steps=25,  # Evaluate and save checkpoints every 50 steps
+                 eval_steps=25,  # Evaluate and save checkpoints every 25 steps
                do_eval=True,  # Perform evaluation at the end of training
            ),
            data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )

-         model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+         log.info("start train")
+         cb({"status": "start_train"})
+         model.config.use_cache = False  # silence the warnings
+
        trainer.train()

-         res = {"status": "done", "checkpoint": str(base64.b64encode(b"checkpoint"))}
-         yield res
+         tmp = self.temp_file(run_name, wipe=True)
+         tokenizer.save_pretrained(tmp)
+
+         # pass base_model_id and output_dir along; return_final needs them for the merge and cleanup
+         self.return_final(run_name, model, base_model_id, output_dir, cb)
+
+     def return_final(self, run_name, model, base_model_id, output_dir, cb):
+         log.info("return final")
+
+         tmp = self.temp_file(run_name)
+
+         # send up lora
+         model.save_pretrained(tmp, safe_serialization=True)
+         gz = gzip(tmp)
+         shutil.rmtree(tmp)
+         with open(gz, "rb") as fil:
+             while True:
+                 dat = fil.read(100000)
+                 if not dat:
+                     break
+                 res = {"status": "lora", "chunk": str(base64.b64encode(dat))}
+                 cb(res)
+
+         log.info("merge weights")
+
+         # merge weights
+
+         # reload as float16 for merge
+         del model
+         gc.collect()
+
+         from peft import PeftModel
+         model = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, local_files_only=True, device_map="auto"), tmp)
+         model = model.merge_and_unload()
+
+         gc.collect()
+         model.save_pretrained(tmp)
+
+         # convert to gguf for fast inference
+         log.info("ggml convert")
+         gguf_main([tmp])
+         gg = tmp + "/ggml-model-f16.gguf"
+         with open(gg, "rb") as fil:
+             while True:
+                 dat = fil.read(100000)
+                 if not dat:
+                     break
+                 res = {"status": "gguf", "chunk": str(base64.b64encode(dat))}
+                 cb(res)
+
+         shutil.rmtree(tmp)
+         shutil.rmtree(output_dir)
+
+         res = {"status": "done"}
+         log.info("done train")
+         cb(res)

    async def download_file(self, training_url: str) -> str:
        output_file = self.temp_file(hashlib.md5(training_url.encode()).hexdigest())