feat: integrate PyTorch Lightning into training pipeline
- Refactored train_hnet.py to use LightningModule
- Updated training and validation steps with Lightning's hooks
- Added PyTorch Lightning configurations
- Need to further integrate Lightning and Hydra by leveraging Julien HAURET's VibraVox
MaloOLIVIER committed Nov 29, 2024
1 parent d8fc140 commit c955062
Showing 4 changed files with 73 additions and 157 deletions.
220 changes: 68 additions & 152 deletions hungarian_net/train_hnet.py
@@ -3,17 +3,65 @@
import random
import time

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from hungarian_net.dataset import HungarianDataset
from hungarian_net.models import HNetGRU


class HNetGRULightning(L.LightningModule):
def __init__(self, max_len, sample_range_used, class_imbalance):
super().__init__()
self.model = HNetGRU(max_len=max_len)
self.criterion1 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion2 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion3 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion_wts = [1.0, 1.0, 1.0]
self.sample_range_used = sample_range_used
self.class_imbalance = class_imbalance
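        # Note (hedged suggestion, not part of this commit): calling
        # self.save_hyperparameters() here would let load_from_checkpoint()
        # restore these constructor arguments automatically.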

def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_idx):
data, target = batch
output1, output2, output3 = self(data)
l1 = self.criterion1(output1, target[0])
l2 = self.criterion2(output2, target[1])
l3 = self.criterion3(output3, target[2])
loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))
self.log("train_loss", loss)
return loss

def validation_step(self, batch, batch_idx):
data, target = batch
output1, output2, output3 = self(data)
l1 = self.criterion1(output1, target[0])
l2 = self.criterion2(output2, target[1])
l3 = self.criterion3(output3, target[2])
loss = sum(w * l for w, l in zip(self.criterion_wts, [l1, l2, l3]))
self.log("val_loss", loss)
# Calculate F1 Score or other metrics here
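        # Hedged sketch, ported from the removed test loop: weighted F1 on the
        # flattened association matrix (assumes target[0] is the binary
        # association matrix; the old per-sample weights are omitted here).
        f_pred = (torch.sigmoid(output1).cpu().numpy() > 0.5).reshape(-1)
        f_ref = target[0].cpu().numpy().reshape(-1)
        self.log("val_f1", f1_score(f_ref, f_pred, zero_division=1, average="weighted"))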
return loss

def configure_optimizers(self):
return optim.Adam(self.parameters())


# @hydra.main(
# config_path="configs",
# config_name="run.yaml",
# version_base="1.3",
# )
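#
# Hedged sketch of the pending Hydra integration (config keys below are
# assumptions, not part of this commit):
#
# @hydra.main(config_path="configs", config_name="run.yaml", version_base="1.3")
# def main(cfg: DictConfig):
#     ...  # e.g. batch_size=cfg.batch_size, nb_epochs=cfg.nb_epochs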
def main(
batch_size=256,
nb_epochs=1000,
@@ -99,6 +147,9 @@ def main(
The trained HNetGRU model with the best validation F1 score.
"""

    # TODO: Rewrite/refactor the code following Julien HAURET's VibraVox
    # TODO: leverage TensorBoard, Hydra, PyTorch Lightning, Ray Tune, Docker

set_seed()

    # Check whether to run on CPU or GPU
@@ -129,158 +180,23 @@ def main(
drop_last=True,
)

# load Hnet model and loss functions
model = HNetGRU(max_len=max_len).to(device)
optimizer = optim.Adam(model.parameters())

criterion1 = torch.nn.BCEWithLogitsLoss(reduction="sum")
criterion2 = torch.nn.BCEWithLogitsLoss(reduction="sum")
criterion3 = torch.nn.BCEWithLogitsLoss(reduction="sum")
criterion_wts = [1.0, 1.0, 1.0]

# Start training
best_f = -1
best_epoch = -1
for epoch in range(1, nb_epochs + 1):
train_start = time.time()
# TRAINING
model.train()
train_loss, train_l1, train_l2, train_l3 = 0, 0, 0, 0
for batch_idx, (data, target) in enumerate(train_loader):
data = data.to(device).float()
target1 = target[0].to(device).float()
target2 = target[1].to(device).float()
target3 = target[2].to(device).float()

optimizer.zero_grad()

output1, output2, output3 = model(data)

l1 = criterion1(output1, target1)
l2 = criterion2(output2, target2)
l3 = criterion3(output3, target3)
loss = criterion_wts[0] * l1 + criterion_wts[1] * l2 + criterion_wts[2] * l3

loss.backward()
optimizer.step()

train_l1 += l1.item()
train_l2 += l2.item()
train_l3 += l3.item()
train_loss += loss.item()

train_l1 /= len(train_loader.dataset)
train_l2 /= len(train_loader.dataset)
train_l3 /= len(train_loader.dataset)
train_loss /= len(train_loader.dataset)
train_time = time.time() - train_start

# TESTING
test_start = time.time()
model.eval()
test_loss, test_l1, test_l2, test_l3 = 0, 0, 0, 0
test_f = 0
nb_test_batches = 0
true_positives, false_positives, false_negatives = 0, 0, 0
f1_score_unweighted = 0
with torch.no_grad():
for data, target in test_loader:
data = data.to(device).float()
target1 = target[0].to(device).float()
target2 = target[1].to(device).float()
target3 = target[2].to(device).float()

output1, output2, output3 = model(data)
l1 = criterion1(output1, target1)
l2 = criterion2(output2, target2)
l3 = criterion3(output3, target3)
loss = (
criterion_wts[0] * l1
+ criterion_wts[1] * l2
+ criterion_wts[2] * l3
)

test_l1 += l1.item()
test_l2 += l2.item()
test_l3 += l3.item()
test_loss += loss.item() # sum up batch loss

f_pred = (torch.sigmoid(output1).cpu().numpy() > 0.5).reshape(-1)
f_ref = target1.cpu().numpy().reshape(-1)
test_f += f1_score(
f_ref,
f_pred,
zero_division=1,
average="weighted",
sample_weight=f_score_weights,
)
nb_test_batches += 1

true_positives += np.sum((f_pred == 1) & (f_ref == 1))
false_positives += np.sum((f_pred == 1) & (f_ref == 0))
false_negatives += np.sum((f_pred == 0) & (f_ref == 1))

f1_score_unweighted += (
2
* true_positives
/ (2 * true_positives + false_positives + false_negatives)
)

test_l1 /= len(test_loader.dataset)
test_l2 /= len(test_loader.dataset)
test_l3 /= len(test_loader.dataset)
test_loss /= len(test_loader.dataset)
test_f /= nb_test_batches
test_time = time.time() - test_start
weighted_accuracy = train_dataset.compute_weighted_accuracy(
true_positives, false_positives
)

f1_score_unweighted /= nb_test_batches

# Early stopping
if test_f > best_f:
best_f = test_f
best_epoch = epoch

# Get current date
current_date = datetime.datetime.now().strftime("%Y%m%d")

# TODO: change model filename - leverage TensorBoard

os.makedirs(f"models/{current_date}", exist_ok=True)

# Human-readable filename
out_filename = f"models/{current_date}/hnet_model_DOA{max_len}_{'-'.join(map(str, sample_range_used))}.pt"

torch.save(model.state_dict(), out_filename)

model_to_return = model
print(
"Epoch: {}\t time: {:0.2f}/{:0.2f}\ttrain_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\ttest_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\tf_scr: {:.4f}\tbest_epoch: {}\tbest_f_scr: {:.4f}\ttrue_positives: {}\tfalse_positives: {}\tweighted_accuracy: {:.4f}".format(
epoch,
train_time,
test_time,
train_loss,
train_l1,
train_l2,
train_l3,
test_loss,
test_l1,
test_l2,
test_l3,
test_f,
best_epoch,
best_f,
true_positives,
false_positives,
weighted_accuracy,
)
)
print("F1 Score (unweighted) : {:.4f}".format(f1_score_unweighted))
print("Best epoch : {}\nBest F1 score : {}".format(best_epoch, best_f))

return model_to_return
model = HNetGRULightning(
max_len=max_len,
sample_range_used=sample_range_used,
class_imbalance=class_imbalance,
)

logger = TensorBoardLogger("tb_logs", name="hnet_model")
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")

trainer = L.Trainer(
max_epochs=nb_epochs,
logger=logger,
callbacks=[checkpoint_callback],
        # accelerator="gpu" if use_cuda else "cpu",  # the old `gpus=` arg was removed in Lightning 2.x
)

trainer.fit(model, train_loader, test_loader)
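
    # Hedged sketch (Lightning 2.x): the best checkpoint picked by
    # ModelCheckpoint can be reloaded afterwards; the __init__ arguments must
    # be re-supplied because save_hyperparameters() is not called above.
    # best_model = HNetGRULightning.load_from_checkpoint(
    #     checkpoint_callback.best_model_path,
    #     max_len=max_len,
    #     sample_range_used=sample_range_used,
    #     class_imbalance=class_imbalance,
    # )

    # Return the trained model, as documented in the docstring above
    return model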


def set_seed(seed=42):
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,4 +4,5 @@ scikit-learn>=1.5.2
scipy>=1.14.1
torch>=2.5.1
torchaudio>=2.5.1
tensorboard>=2.18.0
lightning>=2.4.0
2 changes: 0 additions & 2 deletions tests/nonregression_tests/conftest.py
@@ -1,8 +1,6 @@
# tests/nonregression_tests/conftest.py

import pytest
from pytest_mock import mocker

from hungarian_net.models import HNetGRU


5 changes: 3 additions & 2 deletions tests/scenarios_tests/model/test_scenarios_train_hnet.py
@@ -1,9 +1,7 @@
# tests/scenarios_tests/model/test_train_hnet.py

import re

import pytest

from hungarian_net.train_hnet import main


@@ -62,6 +60,9 @@ def test_train_model_under_various_distributions(
else:
sample_range_used = None # Default values

    # Override nb_epochs to 1 to keep the scenario test fast
nb_epochs = 1

main(
batch_size=batch_size,
nb_epochs=nb_epochs,
