Skip to content

Commit e3b132c

Browse files
committed
e2e test worker
1 parent 38ef113 commit e3b132c

9 files changed

+286
-85
lines changed

ai_worker/fine_tune.py

+161-38
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,59 @@
1+
import gc
12
import base64
23
import hashlib
34
import json
5+
import asyncio
6+
import threading
47
import logging
58
import os
69
import random
10+
import tarfile
11+
import shutil
712

813
import transformers
914
from datasets import load_dataset
1015
from httpx import AsyncClient, Response
1116

1217
import torch
13-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
18+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback
1419

15-
from ai_worker.jsonlines import load_jsonlines
20+
from gguf_loader.convert import main as gguf_main
1621

1722
MAX_CONTEXT = 300000
1823

1924
log = logging.getLogger(__name__)
2025

26+
def gzip(folder):
    """Tar-gzip *folder* into '<folder>.tar.gz' and return the archive path.

    The folder itself is left in place — despite what the old docstring said,
    nothing here removes it; callers rmtree it themselves when done.
    (The name shadows the stdlib `gzip` module; kept for compatibility.)

    Args:
        folder: path to an existing directory to archive.

    Returns:
        Path to the created archive, f"{folder}.tar.gz".
    """
    base_folder_name = os.path.basename(folder)
    # arcname= keeps archive members rooted at the folder's basename,
    # not its absolute path
    with tarfile.open(f"{folder}.tar.gz", 'w:gz') as archive:
        archive.add(folder, arcname=base_folder_name)
    return f"{folder}.tar.gz"
2132

2233
class FineTuner:
2334
def __init__(self, conf):
2435
self.conf = conf
2536
os.makedirs(self.conf.tmp_dir, exist_ok=True)
2637

27-
def temp_file(self, name):
28-
return os.path.join(self.conf.tmp_dir, name)
38+
def temp_file(self, name, wipe=False):
39+
ret = os.path.join(self.conf.tmp_dir, name)
40+
if wipe:
41+
shutil.rmtree(ret, ignore_errors=True)
42+
return ret
2943

3044
def massage_line(self, ln, job):
    """Rewrite one JSONL training line for the target model.

    For Mistral-family models the system-role messages are dropped
    (their chat template did not accept the system role).
    Other models get the line back untouched.
    """
    # TODO: detect role support from the chat template instead of the model name
    if "mistral" not in job["model"].lower():
        return ln
    rec = json.loads(ln)
    rec["messages"] = [msg for msg in rec["messages"] if msg["role"] != "system"]
    return json.dumps(rec) + "\n"
3252

3353
def massage_fine_tune(self, file, job):
3454
cnt = 0
55+
tc = 0
56+
ec = 0
3557
training_split_pct = job.get("hyperparameters", {}).get("training_split", 0.8)
3658

3759
train_file = file + ".train"
@@ -41,42 +63,70 @@ def massage_fine_tune(self, file, job):
4163
with open(eval_file, "w") as ef:
4264
with open(file, "r") as inp:
4365
ln = inp.readline(MAX_CONTEXT)
44-
ln = self.massage_line(ln, job)
4566
while ln:
67+
ln = self.massage_line(ln, job)
4668
cnt += 1
47-
if random.random() > training_split_pct:
69+
if ec and (random.random() > training_split_pct or tc <= ec):
70+
tc += 1
4871
tf.write(ln)
4972
else:
73+
ec += 1
5074
ef.write(ln)
5175
ln = inp.readline(MAX_CONTEXT)
52-
ln = self.massage_line(ln, job)
5376
return train_file, eval_file
5477

5578
async def fine_tune(self, job):
    """Download the training data, then run the fine-tune on a worker
    thread, yielding progress/status dicts as they arrive.

    The blocking training work runs in a daemon thread; results hop back
    to the event loop via call_soon_threadsafe onto an asyncio.Queue.
    A None sentinel from the worker ends the stream.
    """
    log.info("fine tuning: %s", job)

    yield {"status": "download_data"}

    job["training_file"] = await self.download_file(job["training_file"])

    queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def post(res):
        # called from the worker thread: marshal back onto the loop thread
        loop.call_soon_threadsafe(queue.put_nowait, res)

    worker = threading.Thread(
        target=lambda: self._fine_tune(job, post),
        daemon=True,
    )
    worker.start()

    # drain status dicts until the worker posts its None sentinel
    while (res := await queue.get()) is not None:
        yield res
    log.info("DONE")
    worker.join()
101+
102+
def _fine_tune(self, job, cb):
    """Worker-thread entry point: run the training job, funneling errors
    through *cb* instead of letting them escape the thread.

    Contract relied on by fine_tune(): any exception becomes a
    {"status": "error", ...} dict, and cb(None) is ALWAYS sent last so
    the async consumer's queue loop can terminate.
    """
    try:
        self._unsafe_fine_tune(job, cb)
    except Exception as ex:
        # log with traceback, then report a serializable summary upstream
        log.exception("error in fine tune")
        cb({"status": "error", "detail": repr(ex)})
    finally:
        # sentinel: tells the consumer no more results are coming
        cb(None)
110+
111+
def _unsafe_fine_tune(self, job, cb):
112+
training_file = job["training_file"]
64113
train_file, eval_file = self.massage_fine_tune(training_file, job)
65114

66-
train_dataset = load_jsonlines(open(train_file))
67-
eval_dataset = load_jsonlines(open(eval_file))
68-
69-
# todo: use user's model request
115+
base_model = job["model"]
116+
datasets = load_dataset("json", data_files={"train": train_file, "eval": eval_file})
117+
train_dataset = datasets["train"]
118+
eval_dataset = datasets["eval"]
70119

71-
base_model_id = "mistralai/Mistral-7B-v0.1"
120+
base_model_id = base_model.split(":")[0]
72121

73122
# todo: use hyperparams and Q_ filter, if present, for this
74123

75124
hp = job.get("hyperparameters", {})
76125

77126
args = {}
78127

79-
yield {"status": "loading_model"}
128+
log.info("load model")
129+
cb({"status": "load_model"})
80130

81131
args.update(dict(
82132
load_in_4bit=True,
@@ -89,25 +139,23 @@ async def fine_tune(self, job):
89139

90140
# todo: ideally we use llama cpp, but the cuda support for finetune isn't there
91141

92-
model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
93-
94142
tokenizer = AutoTokenizer.from_pretrained(
95143
base_model_id,
96144
padding_side="left",
97145
add_eos_token=True,
98146
add_bos_token=True,
99147
)
100148

101-
train_dataset = tokenizer.apply_chat_template(train_dataset)
102-
eval_dataset = tokenizer.apply_chat_template(eval_dataset)
103-
149+
# sadly, does not take generators, just loads everything in ram
104150
tokenizer.pad_token = tokenizer.eos_token
105-
106-
max_length = 512
107-
151+
# todo: derive from model params
152+
max_length = 4096
108153
def generate_and_tokenize_prompt(prompt):
154+
# all input is openai formatted, and we clean it up above if needed
155+
pr = prompt["messages"]
156+
tmpl = tokenizer.apply_chat_template(pr, tokenize=False)
109157
result = tokenizer(
110-
prompt,
158+
tmpl,
111159
truncation=True,
112160
max_length=max_length,
113161
padding="max_length",
@@ -118,7 +166,9 @@ def generate_and_tokenize_prompt(prompt):
118166
tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
119167
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)
120168

121-
from peft import prepare_model_for_kbit_training
169+
model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto", resume_download=True)
170+
171+
from peft import prepare_model_for_kbit_training, PeftModel
122172

123173
model.gradient_checkpointing_enable()
124174
model = prepare_model_for_kbit_training(model)
@@ -127,7 +177,7 @@ def generate_and_tokenize_prompt(prompt):
127177

128178
config = LoraConfig(
129179
r=32,
130-
lora_alpha=64,
180+
lora_alpha=hp.get("lora_alpha", 64),
131181
target_modules=[
132182
"q_proj",
133183
"k_proj",
@@ -139,7 +189,7 @@ def generate_and_tokenize_prompt(prompt):
139189
"lm_head",
140190
],
141191
bias="none",
142-
lora_dropout=0.05, # Conventional
192+
lora_dropout=hp.get("lora_dropout", 0.05), # Conventional
143193
task_type="CAUSAL_LM",
144194
)
145195

@@ -162,41 +212,114 @@ def generate_and_tokenize_prompt(prompt):
162212
model.model_parallel = True
163213

164214
project = "journal-finetune"
165-
base_model_name = "mistral"
166-
run_name = base_model_name + "-" + project
215+
base_model_name = base_model_id.split("/")[-1]
216+
run_name = base_model_name + "-" + project + "-" + os.urandom(16).hex()
167217
output_dir = "./" + run_name
168218

169219
tokenizer.pad_token = tokenizer.eos_token
170220

221+
class EarlyStoppingCallback(TrainerCallback):
222+
def on_log(self, args, state, control, logs=None, **kwargs):
223+
cb({"status": "log", "logs": logs})
224+
eval_loss = logs.get("eval_loss", None)
225+
if eval_loss is not None and eval_loss <= hp.get("stop_eval_loss", 0.05):
226+
print("Early stopping criterion reached!")
227+
control.should_training_stop = True
228+
229+
def on_save(self, args, state, control, **kwargs):
230+
checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
231+
log.info(f"checkpoint {checkpoint_dir}")
232+
cb({"status": "checkpoint"})
233+
234+
171235
trainer = transformers.Trainer(
172236
model=model,
173237
train_dataset=tokenized_train_dataset,
174238
eval_dataset=tokenized_val_dataset,
239+
callbacks=[EarlyStoppingCallback()],
175240
args=transformers.TrainingArguments(
176241
output_dir=output_dir,
177242
warmup_steps=1,
178-
per_device_train_batch_size=2,
179-
gradient_accumulation_steps=1,
180-
max_steps=500,
181-
learning_rate=2.5e-5, # Want a small lr for finetuning
243+
per_device_train_batch_size=hp.get("batch_size", 4),
244+
gradient_accumulation_steps=hp.get("accumulation_steps", 4),
245+
max_steps=hp.get("max_steps", -1),
246+
num_train_epochs=hp.get("n_epochs", 3), # use openai terminology here
247+
learning_rate=hp.get("learning_rate_multiplier", 2.5e-5), # Want a small lr for finetuning
182248
bf16=True,
183249
optim="paged_adamw_8bit",
184250
logging_steps=25, # When to start reporting loss
185251
logging_dir="./logs", # Directory for storing logs
186252
save_strategy="steps", # Save the model checkpoint every logging step
187253
save_steps=25, # Save checkpoints
254+
save_total_limit=5, # Save checkpoints
255+
load_best_model_at_end=True,
188256
evaluation_strategy="steps", # Evaluate the model every logging step
189-
eval_steps=25, # Evaluate and save checkpoints every 50 steps
257+
eval_steps=25, # Evaluate and save checkpoints every 25 steps
190258
do_eval=True, # Perform evaluation at the end of training
191259
),
192260
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
193261
)
194262

195-
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
263+
log.info("start train")
264+
cb({"status": "start_train"})
265+
model.config.use_cache = False # silence the warnings
266+
196267
trainer.train()
197268

198-
res = {"status": "done", "checkpoint": str(base64.b64encode(b"checkpoint"))}
199-
yield res
269+
tmp = self.temp_file(run_name, wipe=True)
270+
tokenizer.save_pretrained(tmp)
271+
272+
self.return_final(run_name, model, cb)
273+
274+
def return_final(self, run_name, model, cb, base_model_id=None, output_dir=None):
    """Upload the trained LoRA adapter, merge it into the base model,
    convert the result to GGUF, and stream both artifacts through *cb*.

    Args:
        run_name: training-run name; keys the temp dir used for artifacts.
        model: the trained PEFT (LoRA) model.
        cb: callback receiving status dicts; file bytes arrive as base64
            "chunk" entries under statuses "lora" and "gguf".
        base_model_id: HF id of the base model to merge the adapter into.
            NOTE(review): the original body read `base_model_id` and
            `output_dir` as undefined names (locals of _unsafe_fine_tune),
            which raises NameError — they are now explicit parameters and
            callers must pass them.
        output_dir: trainer checkpoint dir to delete at the end (skipped
            when None).
    """
    # local import: peft is heavyweight and only needed for the merge
    from peft import PeftModel

    log.info("return final")

    tmp = self.temp_file(run_name)

    # send up lora first, so the caller has the adapter even if the merge fails
    model.save_pretrained(tmp, safe_serialization=True)
    gz = gzip(tmp)
    # BUG FIX: the original rmtree'd tmp here and then tried to load the
    # adapter from tmp below — the dir must survive until after the merge load
    self._stream_file(gz, "lora", cb)

    log.info("merge weights")

    # reload the base model as float16 for the merge; free the 4-bit copy first
    del model
    gc.collect()

    base = AutoModelForCausalLM.from_pretrained(
        base_model_id, torch_dtype=torch.float16, local_files_only=True, device_map="auto"
    )
    merged = PeftModel.from_pretrained(base, tmp).merge_and_unload()

    # adapter files are no longer needed; start tmp fresh for the merged model
    shutil.rmtree(tmp)
    gc.collect()
    merged.save_pretrained(tmp)

    # convert to gguf for fast inference
    log.info("ggml convert")
    gguf_main([tmp])
    self._stream_file(os.path.join(tmp, "ggml-model-f16.gguf"), "gguf", cb)

    shutil.rmtree(tmp)
    if output_dir:
        shutil.rmtree(output_dir)

    log.info("done train")
    cb({"status": "done"})

def _stream_file(self, path, status, cb):
    """Send *path* through *cb* as base64-encoded ~100KB chunks tagged *status*.

    NOTE(review): chunks are str(base64.b64encode(...)) — i.e. a "b'...'"
    repr — preserved byte-for-byte for wire compatibility with the existing
    receiver; confirm the receiver's decoding before cleaning this up.
    """
    with open(path, "rb") as fil:
        while True:
            dat = fil.read(100000)
            if not dat:
                break
            cb({"status": status, "chunk": str(base64.b64encode(dat))})
200323

201324
async def download_file(self, training_url: str) -> str:
202325
output_file = self.temp_file(hashlib.md5(training_url.encode()).hexdigest())

ai_worker/jsonlines.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44
def load_jsonlines(fin):
    """Yield one parsed JSON object per line of file-like *fin* until EOF."""
    # readline() returns "" only at end of stream
    while lin := fin.readline():
        yield json.loads(lin)
810

build-bin.sh

+7-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ if [ -z "$cmake" -o -z "$gpu" ]; then
99
exit 1
1010
fi
1111

12+
with_torch=""
13+
if [ "$gpu" == "cuda-torch" ]; then
14+
with_torch="--with torch"
15+
fi
16+
17+
1218
set -o xtrace
1319

1420
python -mvenv "build-$gpu"
@@ -22,7 +28,7 @@ pip uninstall -y llama-cpp-python
2228
rm -f ~/AppData/Local/pypoetry/Cache/artifacts/*/*/*/*/llama*
2329
rm -f ~/.cache/pypoetry/artifacts/*/*/*/*/llama*
2430

25-
CMAKE_ARGS="$cmake" FORCE_CMAKE=1 poetry install
31+
CMAKE_ARGS="$cmake" FORCE_CMAKE=1 poetry install $with_torch
2632

2733
python build-version.py
2834

build-linux.sh

+2
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ set -o xtrace
66

77
./build-bin.sh opencl linux-64 "-DLLAMA_CLBLAST=ON"
88

9+
./build-bin.sh cuda-torch linux-64 "-DLLAMA_CLBLAST=ON"
10+
911
./upload.sh

0 commit comments

Comments
 (0)