cad_fine_tuning_trainer.py
import transformers
import wandb
import datasets

import utils  # project-local helpers (provides print_gpu_utilization)


def freeze_layers_lm(to_freeze, n_to_unfreeze, model):
    """Optionally freeze all but the last n_to_unfreeze transformer blocks of the LM."""
    if to_freeze:
        # Freeze every parameter first, then selectively re-enable gradients
        for parameter in model.parameters():
            parameter.requires_grad = False
        for i, m in enumerate(model.transformer.h):
            # Only un-freeze the last n transformer blocks
            if i + 1 > len(model.transformer.h) - n_to_unfreeze:
                for parameter in m.parameters():
                    parameter.requires_grad = True
        # The final layer norm and the LM head are always trained
        for parameter in model.transformer.ln_f.parameters():
            parameter.requires_grad = True
        for parameter in model.lm_head.parameters():
            parameter.requires_grad = True
        print(f"Froze the first {len(model.transformer.h) - n_to_unfreeze} of the model's layers")
        print(f"Only the last {n_to_unfreeze} of the model's layers will be trained!")
    else:
        print("All the model's layers will be trained!")
    return model


def prepare_training(df_trainset,
                     df_valset,
                     tokenizer,
                     tokenize_in_batch) -> tuple[datasets.Dataset, datasets.Dataset]:
    # Convert the datasets from pandas DataFrames to HuggingFace Datasets
    training_set = datasets.Dataset.from_pandas(df_trainset)
    val_set = datasets.Dataset.from_pandas(df_valset)

    # Tokenize both datasets
    tokenized_train = training_set.map(lambda examples: tokenizer(examples["wrapped_input"],
                                                                  padding="max_length",
                                                                  truncation=True),
                                       batched=tokenize_in_batch)
    # For causal-LM fine-tuning the labels are the input ids themselves
    tokenized_train = tokenized_train.add_column("labels", tokenized_train['input_ids'])

    tokenized_val = val_set.map(lambda examples: tokenizer(examples["wrapped_input"],
                                                           padding="max_length",
                                                           truncation=True),
                                batched=tokenize_in_batch)
    tokenized_val = tokenized_val.add_column("labels", tokenized_val['input_ids'])
    return tokenized_train, tokenized_val


def train(out_dir, lm, trainset, valset, no_cuda, training_cfgs, project_name,
          run_name=None, save_model=True, is_sweep=False):
    with wandb.init(project=project_name, name=run_name):
        if is_sweep:
            # use wandb sweep config dict
            for k in wandb.config.keys():
                training_cfgs[k] = wandb.config[k]

        lm = freeze_layers_lm(training_cfgs['FREEZE_LAYERS'], training_cfgs['UNFREEZE_LAST_N'], lm)

        early_stopping = transformers.EarlyStoppingCallback(
            early_stopping_patience=training_cfgs['STOPPING_PATIENCE'])

        training_args = transformers.TrainingArguments(
            output_dir=out_dir,
            overwrite_output_dir=True,
            no_cuda=no_cuda,
            num_train_epochs=training_cfgs['MAX_EPOCHS'],
            per_device_train_batch_size=training_cfgs['TRAIN_BATCHSIZE'],
            per_device_eval_batch_size=training_cfgs['EVAL_BATCHSIZE'],
            gradient_accumulation_steps=training_cfgs['BATCH_UPDATE'],
            do_eval=True,
            evaluation_strategy=transformers.IntervalStrategy.EPOCH,
            warmup_steps=training_cfgs['WARMUP_STEPS'],
            learning_rate=training_cfgs['LR'],
            adam_epsilon=training_cfgs['ADAM_EPS'],
            weight_decay=training_cfgs['WEIGHT_DECAY'],
            save_total_limit=1,
            save_strategy=transformers.IntervalStrategy.EPOCH,
            load_best_model_at_end=True,
            metric_for_best_model='eval_loss',
            fp16=training_cfgs['fp16'],
            optim=training_cfgs['optim'],
        )

        trainer = transformers.Trainer(
            model=lm,
            args=training_args,
            train_dataset=trainset,
            eval_dataset=valset,
            callbacks=[early_stopping],
        )

        utils.print_gpu_utilization()
        trainer.train()

        if save_model:
            trainer.save_model()
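

# Usage sketch (illustrative, not part of the original training pipeline): a minimal,
# hypothetical example of wiring these helpers together. It assumes a GPT-2 checkpoint,
# pandas DataFrames with a "wrapped_input" text column, the config keys read above, and
# a configured Weights & Biases login, since train() opens a wandb run; the checkpoint
# name, sample data, and config values here are placeholders.
if __name__ == "__main__":
    import pandas as pd

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    lm = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    df_train = pd.DataFrame({"wrapped_input": ["example prompt => example completion"]})
    df_val = pd.DataFrame({"wrapped_input": ["another prompt => another completion"]})
    trainset, valset = prepare_training(df_train, df_val, tokenizer, tokenize_in_batch=True)

    training_cfgs = {
        'FREEZE_LAYERS': True, 'UNFREEZE_LAST_N': 2, 'STOPPING_PATIENCE': 2,
        'MAX_EPOCHS': 1, 'TRAIN_BATCHSIZE': 1, 'EVAL_BATCHSIZE': 1, 'BATCH_UPDATE': 1,
        'WARMUP_STEPS': 0, 'LR': 5e-5, 'ADAM_EPS': 1e-8, 'WEIGHT_DECAY': 0.0,
        'fp16': False, 'optim': 'adamw_torch',
    }
    train("out/", lm, trainset, valset, no_cuda=True, training_cfgs=training_cfgs,
          project_name="cad-fine-tuning", run_name="debug-run", save_model=False)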