Commit e8e4249
1 parent e3b132c

6 files changed: +137 -128 lines

ai_worker/fine_tune.py (+10 -11)

@@ -16,6 +16,9 @@
 
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback
+from peft import prepare_model_for_kbit_training, PeftModel, LoraConfig, get_peft_model
+from accelerate import FullyShardedDataParallelPlugin, Accelerator
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
 
 from gguf_loader.convert import main as gguf_main
 
@@ -168,13 +171,9 @@ def generate_and_tokenize_prompt(prompt):
 
 model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto", resume_download=True)
 
-from peft import prepare_model_for_kbit_training, PeftModel
-
 model.gradient_checkpointing_enable()
 model = prepare_model_for_kbit_training(model)
 
-from peft import LoraConfig, get_peft_model
-
 config = LoraConfig(
     r=32,
     lora_alpha=hp.get("lora_alpha", 64),
@@ -195,9 +194,6 @@ def generate_and_tokenize_prompt(prompt):
 
 model = get_peft_model(model, config)
 
-from accelerate import FullyShardedDataParallelPlugin, Accelerator
-from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
-
 fsdp_plugin = FullyShardedDataParallelPlugin(
     state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
     optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
@@ -211,7 +207,7 @@ def generate_and_tokenize_prompt(prompt):
 model.is_parallelizable = True
 model.model_parallel = True
 
-project = "journal-finetune"
+project = "finetune"
 base_model_name = base_model_id.split("/")[-1]
 run_name = base_model_name + "-" + project + "-" + os.urandom(16).hex()
 output_dir = "./" + run_name
@@ -268,18 +264,18 @@ def on_save(self, args, state, control, **kwargs):
 
     tmp = self.temp_file(run_name, wipe=True)
    tokenizer.save_pretrained(tmp)
+    log.info("SAVED", os.listdir(tmp))
 
-    self.return_final(run_name, model, cb)
+    self.return_final(run_name, model, base_model_id, cb)
 
-def return_final(self, run_name, model, cb):
+def return_final(self, run_name, model, base_model_id, cb):
     log.info("return final")
 
     tmp = self.temp_file(run_name)
 
     # send up lora
     model.save_pretrained(tmp, safe_serialization=True)
     gz = gzip(tmp)
-    shutil.rmtree(tmp)
     with open(gz, "rb") as fil:
         while True:
             dat = fil.read(100000)
@@ -296,10 +292,13 @@ def return_final(self, run_name, model, cb):
     del model
     gc.collect()
 
+    # reload with f16
     model = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, local_files_only=True, device_map="auto"), tmp)
     model = model.merge_and_unload()
 
     gc.collect()
+
+    shutil.rmtree(tmp)
     model.save_pretrained(tmp)
 
     # convert to gguf for fast inference
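
Taken together, the fine_tune.py changes hoist the peft, accelerate, and FSDP imports to module level, shorten the run prefix from "journal-finetune" to "finetune", thread base_model_id through to return_final, and fix an ordering bug: shutil.rmtree(tmp) previously deleted the adapter directory before PeftModel.from_pretrained had read it back for merging. A minimal sketch of the corrected ordering follows; it wraps the free names from the diff (model, tmp, base_model_id, and the repo's directory-gzip helper) as parameters, so those names are assumptions, not the repo's API.

import gc
import shutil

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

def upload_then_merge(model, tmp, base_model_id, gzip_dir):
    # 1. Write the LoRA adapter and compress it for upload.
    model.save_pretrained(tmp, safe_serialization=True)
    gz = gzip_dir(tmp)
    # ... stream `gz` up in 100 KB chunks, as the diff does ...

    # 2. Free the quantized training model.
    del model
    gc.collect()

    # 3. Reload the fp16 base and apply the adapter *before* deleting tmp
    #    (the pre-fix code removed tmp first, so this reload would fail).
    base = AutoModelForCausalLM.from_pretrained(
        base_model_id, torch_dtype=torch.float16,
        local_files_only=True, device_map="auto")
    merged = PeftModel.from_pretrained(base, tmp).merge_and_unload()

    # 4. Only now clear tmp and reuse it for the merged fp16 model.
    shutil.rmtree(tmp)
    merged.save_pretrained(tmp)
    return merged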

ai_worker/main.py (+1 -1)

@@ -225,7 +225,7 @@ def _get_connect_info(self) -> ConnectMessage:
 
 connect_msg = ConnectMessage(
     worker_version=VERSION,
-    capabilitied=caps,
+    capabilities=caps,
     worker_id=self.conf.worker_id,
     ln_url=self.conf.ln_address, # todo: remove eventually
     ln_address=self.conf.ln_address,
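
The one-line fix in main.py corrects the misspelled keyword capabilitied to capabilities when building the ConnectMessage. Assuming ConnectMessage is a typed message model (its definition is outside this diff, so the sketch below is hypothetical), the misspelling would either be rejected as an unknown field or leave capabilities unset, meaning a worker would connect while advertising nothing it can do.

# Hypothetical sketch, assuming a pydantic-style ConnectMessage; the real
# definition and field types are not shown in this commit.
from pydantic import BaseModel

class ConnectMessage(BaseModel):
    worker_version: str
    capabilities: dict
    worker_id: str
    ln_url: str  # todo: remove eventually (per the source comment)
    ln_address: str

# With the typo, the required `capabilities` field is never populated, so a
# pydantic model raises a ValidationError at connect time:
#   ConnectMessage(worker_version="1.0", capabilitied=caps, ...)  # fails
msg = ConnectMessage(
    worker_version="1.0",
    capabilities={"gpu": True},
    worker_id="w1",
    ln_url="worker@example.com",
    ln_address="worker@example.com",
)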

build-linux.sh (+1 -1)

@@ -6,6 +6,6 @@ set -o xtrace
 
 ./build-bin.sh opencl linux-64 "-DLLAMA_CLBLAST=ON"
 
-./build-bin.sh cuda-torch linux-64 "-DLLAMA_CLBLAST=ON"
+./build-bin.sh cuda-torch linux-64 "-DLLAMA_CUBLAS=ON"
 
 ./upload.sh
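
In build-linux.sh, the cuda-torch target was being configured with -DLLAMA_CLBLAST=ON, llama.cpp's CMake switch for the OpenCL/CLBlast backend, so the supposedly CUDA-enabled binary was built without cuBLAS; -DLLAMA_CUBLAS=ON selects the CUDA backend instead. A hypothetical sketch of how build-bin.sh presumably forwards that third argument to CMake (its internals are not part of this commit):

#!/bin/bash
# Hypothetical build-bin.sh: pass the third argument through to CMake so the
# caller picks the llama.cpp backend (cuBLAS, CLBlast, ...).
set -o errexit

TARGET="$1"        # e.g. cuda-torch
PLATFORM="$2"      # e.g. linux-64
CMAKE_FLAGS="$3"   # e.g. "-DLLAMA_CUBLAS=ON" or "-DLLAMA_CLBLAST=ON"

cmake -B "build-$TARGET-$PLATFORM" $CMAKE_FLAGS
cmake --build "build-$TARGET-$PLATFORM" --config Release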
