Commit c955062

feat: integrate PyTorch Lightning into training pipeline
- Refactored to use LightningModule
- Updated training and validation steps with Lightning's hooks
- Added PyTorch Lightning configurations
- TODO: further integrate Lightning and Hydra by leveraging Julien HAURET's VibraVox
1 parent d8fc140 commit c955062

File tree

- hungarian_net/train_hnet.py
- requirements.txt
- tests/nonregression_tests/conftest.py
- tests/scenarios_tests/model/test_scenarios_train_hnet.py

4 files changed: +73 −157 lines changed

hungarian_net/train_hnet.py

Lines changed: 68 additions & 152 deletions
@@ -3,17 +3,65 @@
 import random
 import time
 
+import lightning as L
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.loggers import TensorBoardLogger
 from sklearn.metrics import f1_score
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
 
 from hungarian_net.dataset import HungarianDataset
 from hungarian_net.models import HNetGRU
 
 
+class HNetGRULightning(L.LightningModule):
+    def __init__(self, max_len, sample_range_used, class_imbalance):
+        super().__init__()
+        self.model = HNetGRU(max_len=max_len)
+        self.criterion1 = nn.BCEWithLogitsLoss(reduction="sum")
+        self.criterion2 = nn.BCEWithLogitsLoss(reduction="sum")
+        self.criterion3 = nn.BCEWithLogitsLoss(reduction="sum")
+        self.criterion_wts = [1.0, 1.0, 1.0]
+        self.sample_range_used = sample_range_used
+        self.class_imbalance = class_imbalance
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        data, target = batch
+        output1, output2, output3 = self(data)
+        l1 = self.criterion1(output1, target[0])
+        l2 = self.criterion2(output2, target[1])
+        l3 = self.criterion3(output3, target[2])
+        loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        data, target = batch
+        output1, output2, output3 = self(data)
+        l1 = self.criterion1(output1, target[0])
+        l2 = self.criterion2(output2, target[1])
+        l3 = self.criterion3(output3, target[2])
+        loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))
+        self.log("val_loss", loss)
+        # Calculate F1 score or other metrics here
+        return loss
+
+    def configure_optimizers(self):
+        return optim.Adam(self.parameters())
+
+
+# @hydra.main(
+#     config_path="configs",
+#     config_name="run.yaml",
+#     version_base="1.3",
+# )
 def main(
     batch_size=256,
     nb_epochs=1000,
@@ -99,6 +147,9 @@ def main(
         The trained HNetGRU model with the best validation F1 score.
     """
 
+    # TODO: rewrite/refactor the code following Julien HAURET's VibraVox template
+    # TODO: leverage TensorBoard, Hydra, PyTorch Lightning, Ray Tune, Docker
+
     set_seed()
 
     # Check whether to run on CPU or GPU
@@ -129,158 +180,23 @@ def main(
         drop_last=True,
     )
 
-    # load Hnet model and loss functions
-    model = HNetGRU(max_len=max_len).to(device)
-    optimizer = optim.Adam(model.parameters())
-
-    criterion1 = torch.nn.BCEWithLogitsLoss(reduction="sum")
-    criterion2 = torch.nn.BCEWithLogitsLoss(reduction="sum")
-    criterion3 = torch.nn.BCEWithLogitsLoss(reduction="sum")
-    criterion_wts = [1.0, 1.0, 1.0]
-
-    # Start training
-    best_f = -1
-    best_epoch = -1
-    for epoch in range(1, nb_epochs + 1):
-        train_start = time.time()
-        # TRAINING
-        model.train()
-        train_loss, train_l1, train_l2, train_l3 = 0, 0, 0, 0
-        for batch_idx, (data, target) in enumerate(train_loader):
-            data = data.to(device).float()
-            target1 = target[0].to(device).float()
-            target2 = target[1].to(device).float()
-            target3 = target[2].to(device).float()
-
-            optimizer.zero_grad()
-
-            output1, output2, output3 = model(data)
-
-            l1 = criterion1(output1, target1)
-            l2 = criterion2(output2, target2)
-            l3 = criterion3(output3, target3)
-            loss = criterion_wts[0] * l1 + criterion_wts[1] * l2 + criterion_wts[2] * l3
-
-            loss.backward()
-            optimizer.step()
-
-            train_l1 += l1.item()
-            train_l2 += l2.item()
-            train_l3 += l3.item()
-            train_loss += loss.item()
-
-        train_l1 /= len(train_loader.dataset)
-        train_l2 /= len(train_loader.dataset)
-        train_l3 /= len(train_loader.dataset)
-        train_loss /= len(train_loader.dataset)
-        train_time = time.time() - train_start
-
-        # TESTING
-        test_start = time.time()
-        model.eval()
-        test_loss, test_l1, test_l2, test_l3 = 0, 0, 0, 0
-        test_f = 0
-        nb_test_batches = 0
-        true_positives, false_positives, false_negatives = 0, 0, 0
-        f1_score_unweighted = 0
-        with torch.no_grad():
-            for data, target in test_loader:
-                data = data.to(device).float()
-                target1 = target[0].to(device).float()
-                target2 = target[1].to(device).float()
-                target3 = target[2].to(device).float()
-
-                output1, output2, output3 = model(data)
-                l1 = criterion1(output1, target1)
-                l2 = criterion2(output2, target2)
-                l3 = criterion3(output3, target3)
-                loss = (
-                    criterion_wts[0] * l1
-                    + criterion_wts[1] * l2
-                    + criterion_wts[2] * l3
-                )
-
-                test_l1 += l1.item()
-                test_l2 += l2.item()
-                test_l3 += l3.item()
-                test_loss += loss.item()  # sum up batch loss
-
-                f_pred = (torch.sigmoid(output1).cpu().numpy() > 0.5).reshape(-1)
-                f_ref = target1.cpu().numpy().reshape(-1)
-                test_f += f1_score(
-                    f_ref,
-                    f_pred,
-                    zero_division=1,
-                    average="weighted",
-                    sample_weight=f_score_weights,
-                )
-                nb_test_batches += 1
-
-                true_positives += np.sum((f_pred == 1) & (f_ref == 1))
-                false_positives += np.sum((f_pred == 1) & (f_ref == 0))
-                false_negatives += np.sum((f_pred == 0) & (f_ref == 1))
-
-                f1_score_unweighted += (
-                    2
-                    * true_positives
-                    / (2 * true_positives + false_positives + false_negatives)
-                )
-
-        test_l1 /= len(test_loader.dataset)
-        test_l2 /= len(test_loader.dataset)
-        test_l3 /= len(test_loader.dataset)
-        test_loss /= len(test_loader.dataset)
-        test_f /= nb_test_batches
-        test_time = time.time() - test_start
-        weighted_accuracy = train_dataset.compute_weighted_accuracy(
-            true_positives, false_positives
-        )
-
-        f1_score_unweighted /= nb_test_batches
-
-        # Early stopping
-        if test_f > best_f:
-            best_f = test_f
-            best_epoch = epoch
-
-            # Get current date
-            current_date = datetime.datetime.now().strftime("%Y%m%d")
-
-            # TODO: change model filename - leverage TensorBoard
-
-            os.makedirs(f"models/{current_date}", exist_ok=True)
-
-            # Human-readable filename
-            out_filename = f"models/{current_date}/hnet_model_DOA19,268_{'-'.join(map(str, sample_range_used))}.pt"
-
-            torch.save(model.state_dict(), out_filename)
-
-            model_to_return = model
-        print(
-            "Epoch: {}\t time: {:0.2f}/{:0.2f}\ttrain_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\ttest_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\tf_scr: {:.4f}\tbest_epoch: {}\tbest_f_scr: {:.4f}\ttrue_positives: {}\tfalse_positives: {}\tweighted_accuracy: {:.4f}".format(
-                epoch,
-                train_time,
-                test_time,
-                train_loss,
-                train_l1,
-                train_l2,
-                train_l3,
-                test_loss,
-                test_l1,
-                test_l2,
-                test_l3,
-                test_f,
-                best_epoch,
-                best_f,
-                true_positives,
-                false_positives,
-                weighted_accuracy,
-            )
-        )
-    print("F1 Score (unweighted) : {:.4f}".format(f1_score_unweighted))
-    print("Best epoch : {}\nBest F1 score : {}".format(best_epoch, best_f))
-
-    return model_to_return
+    model = HNetGRULightning(
+        max_len=max_len,
+        sample_range_used=sample_range_used,
+        class_imbalance=class_imbalance,
+    )
+
+    logger = TensorBoardLogger("tb_logs", name="hnet_model")
+    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")
+
+    trainer = L.Trainer(
+        max_epochs=nb_epochs,
+        logger=logger,
+        callbacks=[checkpoint_callback],
+        # gpus=1 if use_cuda else 0
+    )
+
+    trainer.fit(model, train_loader, test_loader)
 
 
 def set_seed(seed=42):
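Note on the new validation_step: the metric computation is still a TODO. A minimal sketch of how it could reuse the f1_score call from the deleted test loop (not part of this commit; the old f_score_weights sample weighting is omitted for brevity, and the old loop's explicit .float() casts may still be needed if the dataset yields float64 tensors):

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output1, output2, output3 = self(data)
        l1 = self.criterion1(output1, target[0])
        l2 = self.criterion2(output2, target[1])
        l3 = self.criterion3(output3, target[2])
        loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))

        # Same thresholding as the deleted loop: sigmoid + 0.5 cutoff on the
        # association logits, then a weighted F1 over the flattened batch.
        f_pred = (torch.sigmoid(output1).cpu().numpy() > 0.5).reshape(-1)
        f_ref = target[0].cpu().numpy().reshape(-1)
        val_f1 = f1_score(f_ref, f_pred, zero_division=1, average="weighted")

        self.log("val_loss", loss)
        self.log("val_f1", float(val_f1))
        return loss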

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -4,4 +4,5 @@ scikit-learn>=1.5.2
 scipy>=1.14.1
 torch>=2.5.1
 torchaudio>=2.5.1
-tensorboard>=2.18.0
+tensorboard>=2.18.0
+lightning>=2.4.0
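One caveat with the pinned lightning>=2.4.0: the gpus argument commented out in the new Trainer was removed in Lightning 2.x. A sketch of the 2.x equivalent, assuming the same use_cuda flag computed earlier in main:

    trainer = L.Trainer(
        max_epochs=nb_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        # Lightning 2.x replaces `gpus=1 if use_cuda else 0` with accelerator/devices;
        # accelerator="auto" would also pick a GPU automatically when one is visible.
        accelerator="gpu" if use_cuda else "cpu",
        devices=1,
    )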

tests/nonregression_tests/conftest.py

Lines changed: 0 additions & 2 deletions
@@ -1,8 +1,6 @@
 # tests/nonregression_tests/conftest.py
 
 import pytest
-from pytest_mock import mocker
-
 from hungarian_net.models import HNetGRU
 
 
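Dropping `from pytest_mock import mocker` is the right call: pytest-mock registers mocker as a plugin fixture, so tests receive it by parameter name and the import was never needed. A hypothetical illustration (not from this repo; the input shape is an assumption):

    import torch

    from hungarian_net.models import HNetGRU

    def test_hnetgru_forward_called(mocker):
        # `mocker` is injected by the pytest-mock plugin, not imported.
        spy = mocker.spy(HNetGRU, "forward")
        model = HNetGRU(max_len=2)
        model(torch.randn(1, 2, 2))  # shape assumed: (batch, max_len, max_len)
        assert spy.call_count == 1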
tests/scenarios_tests/model/test_scenarios_train_hnet.py

Lines changed: 3 additions & 2 deletions
@@ -1,9 +1,7 @@
 # tests/scenarios_tests/model/test_train_hnet.py
 
 import re
-
 import pytest
-
 from hungarian_net.train_hnet import main
 
 
@@ -62,6 +60,9 @@ def test_train_model_under_various_distributions(
     else:
         sample_range_used = None  # Default values
 
+    # Force nb_epochs to 1 regardless of the parametrized input, to keep the test fast
+    nb_epochs = 1
+
     main(
         batch_size=batch_size,
         nb_epochs=nb_epochs,