import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback
+ from peft import prepare_model_for_kbit_training, PeftModel, LoraConfig, get_peft_model
+ from accelerate import FullyShardedDataParallelPlugin, Accelerator
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

from gguf_loader.convert import main as gguf_main

@@ -168,13 +171,9 @@ def generate_and_tokenize_prompt(prompt):

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto", resume_download=True)

- from peft import prepare_model_for_kbit_training, PeftModel
-
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

- from peft import LoraConfig, get_peft_model
-
config = LoraConfig(
    r=32,
    lora_alpha=hp.get("lora_alpha", 64),
@@ -195,9 +194,6 @@ def generate_and_tokenize_prompt(prompt):

model = get_peft_model(model, config)

- from accelerate import FullyShardedDataParallelPlugin, Accelerator
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
-
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
@@ -211,7 +207,7 @@ def generate_and_tokenize_prompt(prompt):
model.is_parallelizable = True
model.model_parallel = True

- project = "journal-finetune"
+ project = "finetune"
base_model_name = base_model_id.split("/")[-1]
run_name = base_model_name + "-" + project + "-" + os.urandom(16).hex()
output_dir = "./" + run_name
@@ -268,18 +264,18 @@ def on_save(self, args, state, control, **kwargs):

tmp = self.temp_file(run_name, wipe=True)
tokenizer.save_pretrained(tmp)
+ log.info("SAVED", os.listdir(tmp))

- self.return_final(run_name, model, cb)
+ self.return_final(run_name, model, base_model_id, cb)

- def return_final(self, run_name, model, cb):
+ def return_final(self, run_name, model, base_model_id, cb):
log.info("return final")

tmp = self.temp_file(run_name)

# send up lora
model.save_pretrained(tmp, safe_serialization=True)
gz = gzip(tmp)
- shutil.rmtree(tmp)
with open(gz, "rb") as fil:
    while True:
        dat = fil.read(100000)
@@ -296,10 +292,13 @@ def return_final(self, run_name, model, cb):
del model
gc.collect()

+ # reload with f16
model = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, local_files_only=True, device_map="auto"), tmp)
model = model.merge_and_unload()

gc.collect()
+
+ shutil.rmtree(tmp)
model.save_pretrained(tmp)

# convert to gguf for fast inference