import argparse
import json
import random

import torch
import wandb
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

random.seed(0)

def load_text(fname: str = "sanskrit_corpus.txt") -> list:
    """
    Load a text file and return its lines as a list (split on newlines).

    :param fname: filename to load
    :return: list of str
    """
    with open(file=fname, mode="r") as fp:
        lines = fp.read().split("\n")
    return lines

def load_model_tokenizer(model_name: str, from_scratch=False) -> tuple:
    """
    Load the tokenizer and model from the HF hub using model_name.

    :param model_name: model name to load from the HF hub
    :param from_scratch: if True, skip the pretrained weights and initialize the
        model from its config; otherwise load the pretrained weights
    :return: (tokenizer, model)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, strip_accents=False)
    if from_scratch:
        config = AutoConfig.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_config(config)
    else:
        model = AutoModelForMaskedLM.from_pretrained(model_name)
    return (tokenizer, model)

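# For example, load_model_tokenizer("ai4bharat/indic-bert") returns the pretrained
# masked LM, while load_model_tokenizer("ai4bharat/indic-bert", from_scratch=True)
# keeps the same architecture and vocabulary but starts from randomly initialized
# weights (handy as a from-scratch baseline).
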
def batch_encode(text: list, max_seq_len: int, batch_size=4096) -> tuple:
    """
    Memory-efficient encoding of sentences: tokenize batch_size sentences at a
    time instead of the whole corpus in one call.

    :param text: list of sentences
    :param max_seq_len: maximum sequence length to tokenize to
    :param batch_size: how many sentences to tokenize at once
    :return: (input_ids, attention_masks)
    """
    input_ids = []
    attention_masks = []
    for i in tqdm(range(0, len(text), batch_size)):
        encoded_sent = tokenizer.batch_encode_plus(
            text[i : i + batch_size],
            max_length=max_seq_len,
            add_special_tokens=True,
            padding="max_length",
            return_attention_mask=True,
            truncation=True,
            # return_tensors="pt",
        )
        input_ids += encoded_sent["input_ids"]
        attention_masks += encoded_sent["attention_mask"]
    return (input_ids, attention_masks)

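# This helper is an alternative to the single tokenizer(...) calls further below:
# it bounds peak memory by encoding the corpus in chunks. An illustrative use
# (it is not called by the script as written):
#   input_ids_train, attention_masks_train = batch_encode(text_train, max_seq_len=128)
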
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

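# Each dataset item is a dict of tensors keyed by the tokenizer outputs
# (e.g. "input_ids" and "attention_mask"). No labels are stored here: the data
# collator below builds the MLM labels on the fly for every batch.
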
# argparse for the command-line interface
parser = argparse.ArgumentParser(description="Training data")
parser.add_argument(
    "--checkpoint",
    type=str,
    help="HF hub model name (default: 'ai4bharat/indic-bert')",
    default="ai4bharat/indic-bert",
)
parser.add_argument(
    "--from_scratch",
    action="store_true",
    help="initialize model weights randomly instead of loading pretrained weights",
)
parser.add_argument(
    "--wandb",
    action="store_true",
    help="report metrics to wandb",
)
parser.add_argument(
    "--chkpt_dir",
    type=str,
    default=".",
    help="model artifact directory",
)
args = parser.parse_args()

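# Example invocation (flags and paths are illustrative):
#   python train_model.py --checkpoint ai4bharat/indic-bert --from_scratch --chkpt_dir ./checkpoints
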
print(args.checkpoint)

# getting command-line arguments
model_name = args.checkpoint
from_scratch = args.from_scratch
report_to_wandb = "wandb" if args.wandb else "none"
chkpt_dir = args.chkpt_dir

# loading the model and tokenizer based on the arguments
tokenizer, model = load_model_tokenizer(
    model_name=model_name, from_scratch=from_scratch
)

# loading raw texts
text_train = load_text("sanskrit_corpus_train.txt")
text_eval = load_text("sanskrit_corpus_eval.txt")

# batch tokenizing the texts
tokenized_train = tokenizer(text_train, padding=True, truncation=True, max_length=128)
tokenized_eval = tokenizer(text_eval, padding=True, truncation=True, max_length=128)

# making torch Dataset objects for the tokenized sentences
dataset_train = Dataset(tokenized_train)
dataset_eval = Dataset(tokenized_eval)

# the data collator does the masking of input tokens during training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15
)
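# With the default mlm=True, the collator selects 15% of the tokens in each batch
# as prediction targets; by default a selected position is replaced with the mask
# token 80% of the time, a random token 10% of the time, and left unchanged 10% of
# the time, and the labels of all unselected positions are set to -100 so the loss
# ignores them.
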
# https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/trainer#transformers.TrainingArguments
# the training arguments below can be adjusted as needed
if args.wandb:
    wandb.login()

training_args = TrainingArguments(
    report_to=report_to_wandb,
    output_dir=f"{chkpt_dir}/results_scratch_{str(from_scratch)}",  # separate folders for the pretrained and from-scratch runs
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    do_train=True,
    do_eval=True,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# start training
trainer.train()
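
# Optional follow-up (a sketch; the save path is illustrative): persist the final
# model and tokenizer and log one last evaluation pass.
trainer.save_model(f"{chkpt_dir}/final_model_scratch_{str(from_scratch)}")
tokenizer.save_pretrained(f"{chkpt_dir}/final_model_scratch_{str(from_scratch)}")
print(trainer.evaluate())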