3 | 3 | import random
4 | 4 | import time
5 | 5 |
| 6 | +import lightning as L |
6 | 7 | import numpy as np
7 | 8 | import torch
8 | 9 | import torch.nn as nn
9 | 10 | import torch.optim as optim
| 11 | +from lightning.pytorch.callbacks import ModelCheckpoint |
| 12 | +from lightning.pytorch.loggers import TensorBoardLogger |
10 | 13 | from sklearn.metrics import f1_score
11 | 14 | from torch.utils.data import DataLoader
| 15 | +from torch.utils.tensorboard import SummaryWriter |
12 | 16 |
13 | 17 | from hungarian_net.dataset import HungarianDataset
14 | 18 | from hungarian_net.models import HNetGRU
15 | 19 |
16 | 20 |
| 21 | +class HNetGRULightning(L.LightningModule): |
| 22 | + def __init__(self, max_len, sample_range_used, class_imbalance): |
| 23 | + super().__init__() |
| 24 | + self.model = HNetGRU(max_len=max_len) |
| 25 | + self.criterion1 = nn.BCEWithLogitsLoss(reduction="sum") |
| 26 | + self.criterion2 = nn.BCEWithLogitsLoss(reduction="sum") |
| 27 | + self.criterion3 = nn.BCEWithLogitsLoss(reduction="sum") |
| 28 | + self.criterion_wts = [1.0, 1.0, 1.0] |
| 29 | + self.sample_range_used = sample_range_used |
| 30 | + self.class_imbalance = class_imbalance |
| 31 | + |
| 32 | + def forward(self, x): |
| 33 | + return self.model(x) |
| 34 | + |
| 35 | + def training_step(self, batch, batch_idx): |
| 36 | + data, target = batch |
| 37 | + output1, output2, output3 = self(data) |
| 38 | + l1 = self.criterion1(output1, target[0]) |
| 39 | + l2 = self.criterion2(output2, target[1]) |
| 40 | + l3 = self.criterion3(output3, target[2]) |
| 41 | + loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3])) |
| 42 | + self.log("train_loss", loss) |
| 43 | + return loss |
| 44 | + |
| 45 | + def validation_step(self, batch, batch_idx): |
| 46 | + data, target = batch |
| 47 | + output1, output2, output3 = self(data) |
| 48 | + l1 = self.criterion1(output1, target[0]) |
| 49 | + l2 = self.criterion2(output2, target[1]) |
| 50 | + l3 = self.criterion3(output3, target[2]) |
| 51 | + loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3])) |
| 52 | + self.log("val_loss", loss) |
| 53 | + # Calculate F1 score or other metrics here (see the sketch after this class) |
| 54 | + return loss |
| 55 | + |
| 56 | + def configure_optimizers(self): |
| 57 | + return optim.Adam(self.parameters()) |
| 58 | + |
| 59 | + |
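Note: validation_step above only logs the loss; the .float() casts and the weighted F1 computation from the removed test loop are not ported yet. A minimal sketch of how the metric could be restored, assuming the sklearn sample weights are handed to the module as an extra constructor argument (called f_score_weights here, which does not exist in this commit):

    def validation_step(self, batch, batch_idx):
        data, target = batch
        # mirror the removed loop's .float() casts so the BCE losses see float32 tensors
        data = data.float()
        target = [t.float() for t in target]
        output1, output2, output3 = self(data)
        l1 = self.criterion1(output1, target[0])
        l2 = self.criterion2(output2, target[1])
        l3 = self.criterion3(output3, target[2])
        loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))
        # threshold the association logits and score them against the reference, as in the removed loop
        f_pred = (torch.sigmoid(output1).detach().cpu().numpy() > 0.5).reshape(-1)
        f_ref = target[0].detach().cpu().numpy().reshape(-1)
        val_f1 = f1_score(f_ref, f_pred, zero_division=1, average="weighted",
                          sample_weight=self.f_score_weights)
        self.log("val_loss", loss)
        self.log("val_f1", val_f1)
        return loss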
| 60 | +# @hydra.main( |
| 61 | +# config_path="configs", |
| 62 | +# config_name="run.yaml", |
| 63 | +# version_base="1.3", |
| 64 | +# ) |
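The @hydra.main decorator stays commented out for now. A rough sketch of what enabling it could look like, assuming configs/run.yaml carries the values that are currently hard-coded as keyword defaults of main (Hydra normally takes config_name without the .yaml extension):

    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path="configs", config_name="run", version_base="1.3")
    def main(cfg: DictConfig):
        # cfg.batch_size, cfg.nb_epochs, etc. would then replace the
        # hard-coded keyword defaults in the signature below
        ...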
17 | 65 | def main(
18 | 66 | batch_size=256,
19 | 67 | nb_epochs=1000,
@@ -99,6 +147,9 @@ def main(
99 | 147 | The trained HNetGRU model with the best validation F1 score.
100 | 148 | """
101 | 149 |
| 150 | + # TODO: Rewrite/refactor the code along the lines of Julien HAURET's VibraVox |
| 151 | + # TODO: leverage TensorBoard, Hydra, PyTorch Lightning, RayTune, Docker |
| 152 | + |
102 | 153 | set_seed()
103 | 154 |
104 | 155 | # Check whether to run on CPU or GPU
@@ -129,158 +180,23 @@ def main(
129 | 180 | drop_last=True,
130 | 181 | )
131 | 182 |
| - # load Hnet model and loss functions |
| - model = HNetGRU(max_len=max_len).to(device) |
| - optimizer = optim.Adam(model.parameters()) |
| - |
| - criterion1 = torch.nn.BCEWithLogitsLoss(reduction="sum") |
| - criterion2 = torch.nn.BCEWithLogitsLoss(reduction="sum") |
| - criterion3 = torch.nn.BCEWithLogitsLoss(reduction="sum") |
| - criterion_wts = [1.0, 1.0, 1.0] |
| - |
| - # Start training |
| - best_f = -1 |
| - best_epoch = -1 |
| - for epoch in range(1, nb_epochs + 1): |
| - train_start = time.time() |
| - # TRAINING |
| - model.train() |
| - train_loss, train_l1, train_l2, train_l3 = 0, 0, 0, 0 |
| - for batch_idx, (data, target) in enumerate(train_loader): |
| - data = data.to(device).float() |
| - target1 = target[0].to(device).float() |
| - target2 = target[1].to(device).float() |
| - target3 = target[2].to(device).float() |
| - |
| - optimizer.zero_grad() |
| - |
| - output1, output2, output3 = model(data) |
| - |
| - l1 = criterion1(output1, target1) |
| - l2 = criterion2(output2, target2) |
| - l3 = criterion3(output3, target3) |
| - loss = criterion_wts[0] * l1 + criterion_wts[1] * l2 + criterion_wts[2] * l3 |
| - |
| - loss.backward() |
| - optimizer.step() |
| - |
| - train_l1 += l1.item() |
| - train_l2 += l2.item() |
| - train_l3 += l3.item() |
| - train_loss += loss.item() |
| - |
| - train_l1 /= len(train_loader.dataset) |
| - train_l2 /= len(train_loader.dataset) |
| - train_l3 /= len(train_loader.dataset) |
| - train_loss /= len(train_loader.dataset) |
| - train_time = time.time() - train_start |
| - |
| - # TESTING |
| - test_start = time.time() |
| - model.eval() |
| - test_loss, test_l1, test_l2, test_l3 = 0, 0, 0, 0 |
| - test_f = 0 |
| - nb_test_batches = 0 |
| - true_positives, false_positives, false_negatives = 0, 0, 0 |
| - f1_score_unweighted = 0 |
| - with torch.no_grad(): |
| - for data, target in test_loader: |
| - data = data.to(device).float() |
| - target1 = target[0].to(device).float() |
| - target2 = target[1].to(device).float() |
| - target3 = target[2].to(device).float() |
| - |
| - output1, output2, output3 = model(data) |
| - l1 = criterion1(output1, target1) |
| - l2 = criterion2(output2, target2) |
| - l3 = criterion3(output3, target3) |
| - loss = ( |
| - criterion_wts[0] * l1 |
| - + criterion_wts[1] * l2 |
| - + criterion_wts[2] * l3 |
| - ) |
| - |
| - test_l1 += l1.item() |
| - test_l2 += l2.item() |
| - test_l3 += l3.item() |
| - test_loss += loss.item() # sum up batch loss |
| - |
| - f_pred = (torch.sigmoid(output1).cpu().numpy() > 0.5).reshape(-1) |
| - f_ref = target1.cpu().numpy().reshape(-1) |
| - test_f += f1_score( |
| - f_ref, |
| - f_pred, |
| - zero_division=1, |
| - average="weighted", |
| - sample_weight=f_score_weights, |
| - ) |
| - nb_test_batches += 1 |
| - |
| - true_positives += np.sum((f_pred == 1) & (f_ref == 1)) |
| - false_positives += np.sum((f_pred == 1) & (f_ref == 0)) |
| - false_negatives += np.sum((f_pred == 0) & (f_ref == 1)) |
| - |
| - f1_score_unweighted += ( |
| - 2 |
| - * true_positives |
| - / (2 * true_positives + false_positives + false_negatives) |
| - ) |
| - |
| - test_l1 /= len(test_loader.dataset) |
| - test_l2 /= len(test_loader.dataset) |
| - test_l3 /= len(test_loader.dataset) |
| - test_loss /= len(test_loader.dataset) |
| - test_f /= nb_test_batches |
| - test_time = time.time() - test_start |
| - weighted_accuracy = train_dataset.compute_weighted_accuracy( |
| - true_positives, false_positives |
| - ) |
| - |
| - f1_score_unweighted /= nb_test_batches |
| - |
| - # Early stopping |
| - if test_f > best_f: |
| - best_f = test_f |
| - best_epoch = epoch |
| - |
| - # Get current date |
| - current_date = datetime.datetime.now().strftime("%Y%m%d") |
| - |
| - # TODO: change model filename - leverage TensorBoard |
| - |
| - os.makedirs(f"models/{current_date}", exist_ok=True) |
| - |
| - # Human-readable filename |
| - out_filename = f"models/{current_date}/hnet_model_DOA19,126_{'-'.join(map(str, sample_range_used))}.pt" |
| - |
| - torch.save(model.state_dict(), out_filename) |
| - |
| - model_to_return = model |
| - print( |
| - "Epoch: {}\t time: {:0.2f}/{:0.2f}\ttrain_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\ttest_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\tf_scr: {:.4f}\tbest_epoch: {}\tbest_f_scr: {:.4f}\ttrue_positives: {}\tfalse_positives: {}\tweighted_accuracy: {:.4f}".format( |
| - epoch, |
| - train_time, |
| - test_time, |
| - train_loss, |
| - train_l1, |
| - train_l2, |
| - train_l3, |
| - test_loss, |
| - test_l1, |
| - test_l2, |
| - test_l3, |
| - test_f, |
| - best_epoch, |
| - best_f, |
| - true_positives, |
| - false_positives, |
| - weighted_accuracy, |
| - ) |
| - ) |
| - print("F1 Score (unweighted) : {:.4f}".format(f1_score_unweighted)) |
| - print("Best epoch : {}\nBest F1 score : {}".format(best_epoch, best_f)) |
| - |
| - return model_to_return |
| 183 | + model = HNetGRULightning( |
| 184 | + max_len=max_len, |
| 185 | + sample_range_used=sample_range_used, |
| 186 | + class_imbalance=class_imbalance, |
| 187 | + ) |
| 188 | + |
| 189 | + logger = TensorBoardLogger("tb_logs", name="hnet_model") |
| 190 | + checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min") |
| 191 | + |
| 192 | + trainer = L.Trainer( |
| 193 | + max_epochs=nb_epochs, |
| 194 | + logger=logger, |
| 195 | + callbacks=[checkpoint_callback], |
| 196 | + # accelerator="gpu" if use_cuda else "cpu", devices=1 (Lightning 2.x dropped the old gpus= argument) |
| 197 | + ) |
| 198 | + |
| 199 | + trainer.fit(model, train_loader, test_loader) |
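Note: the removed loop saved the best weights with torch.save and returned that model, whereas this version relies on the ModelCheckpoint callback. A small sketch of how the best model could still be returned after training, assuming self.save_hyperparameters() is added to HNetGRULightning.__init__ so the constructor arguments can be restored from the checkpoint:

    # reload the top-1 checkpoint kept by the callback after trainer.fit(...)
    best_model = HNetGRULightning.load_from_checkpoint(checkpoint_callback.best_model_path)
    return best_model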
284 | 200 |
285 | 201 |
286 | 202 | def set_seed(seed=42):
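Side note on reproducibility: since the script now depends on Lightning anyway, the hand-rolled set_seed defined here could delegate to Lightning's own helper, which seeds Python's random module, NumPy and PyTorch in one call and can also seed dataloader workers:

    from lightning.pytorch import seed_everything

    seed_everything(42, workers=True)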