Refactor dataset and model for PyTorch Lightning integration
- Add HungarianDataModule using LightningDataModule
- Update HNetGRULightning to use torchmetrics for F1Score
- Implement TensorBoard logging for loss and F1-score
- Modify train_hnet.py to utilize the new DataModule and model classes
MaloOLIVIER committed Dec 2, 2024
1 parent ce4a027 commit 1962dc2
Showing 3 changed files with 277 additions and 72 deletions.
60 changes: 59 additions & 1 deletion hungarian_net/dataset.py
@@ -1,5 +1,6 @@
import numpy as np
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from hungarian_net.generate_hnet_training_data import load_obj

@@ -99,3 +100,60 @@ def compute_weighted_accuracy(self, n1star, n0star):
WA = (w1 * n1star + w0 * n0star) / (w1 * n1 + w0 * n0)

return WA


class HungarianDataModule(LightningDataModule):
"""
LightningDataModule for HungarianDataset.
Args:
train_filename (str): Filename for training data.
test_filename (str): Filename for testing data.
max_len (int, optional): Maximum number of Directions of Arrival (DOAs). Defaults to 2.
batch_size (int, optional): Batch size for data loaders. Defaults to 256.
num_workers (int, optional): Number of workers for data loaders. Defaults to 4.
"""

def __init__(
self, train_filename, test_filename, max_len=2, batch_size=256, num_workers=4
):
super().__init__()
self.train_filename = train_filename
self.test_filename = test_filename
self.max_len = max_len
self.batch_size = batch_size
self.num_workers = num_workers

# def transfer_batch_to_device(self, batch, device, dataloader_idx):

def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = HungarianDataset(
train=True, max_len=self.max_len, filename=self.train_filename
)
self.val_dataset = HungarianDataset(
train=False, max_len=self.max_len, filename=self.test_filename
)
if stage == "test" or stage is None:
self.test_dataset = HungarianDataset(
train=False, max_len=self.max_len, filename=self.test_filename
)

def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
drop_last=True,
)

def val_dataloader(self):
return DataLoader(
self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
)

def test_dataloader(self):
return DataLoader(
self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
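
A minimal usage sketch for the new DataModule (the filenames below are hypothetical placeholders, not paths taken from this commit; only HungarianDataModule and its defaults come from the diff above):

# Hypothetical wiring of HungarianDataModule; paths are assumptions.
from hungarian_net.dataset import HungarianDataModule

dm = HungarianDataModule(
    train_filename="data/hung_data_train",  # assumed path
    test_filename="data/hung_data_test",    # assumed path
    max_len=2,
    batch_size=256,
    num_workers=4,
)
dm.setup("fit")
data, target = next(iter(dm.train_dataloader()))  # one shuffled training batch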
181 changes: 181 additions & 0 deletions hungarian_net/models.py
@@ -1,5 +1,186 @@
from functools import partial
from typing import Any, Dict

import lightning as L
import torch
import torch.nn as nn
import torchmetrics
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import optim
from torchmetrics import MetricCollection


class HNetGRULightning(L.LightningModule):
    """
    LightningModule wrapper around HNetGRU: sums three weighted
    BCE-with-logits losses (one per model output head) and tracks
    weighted F1-scores with torchmetrics for the train/val/test stages.
    """

def __init__(
self,
device,
max_len: int = 2,
optimizer: partial[torch.optim.Optimizer] = partial(optim.Adam),
):
super().__init__()
self._device = device
self.model = HNetGRU(max_len=max_len).to(self._device)

self.criterion1 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion2 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion3 = nn.BCEWithLogitsLoss(reduction="sum")
self.criterion_wts = [1.0, 1.0, 1.0]

self.optimizer: torch.optim.Optimizer = optimizer(
params=self.model.parameters()
)

self.train_f1 = torchmetrics.F1Score(
task="multiclass",
num_classes=2,
average="weighted",
zero_division=1,
).to(self._device)
self.val_f1 = torchmetrics.F1Score(
task="multiclass",
num_classes=2,
average="weighted",
zero_division=1,
).to(self._device)
self.test_f1 = torchmetrics.F1Score(
task="multiclass",
num_classes=2,
average="weighted",
zero_division=1,
).to(self._device)

def common_step(self, batch, batch_idx):
data, target = batch
data = data.to(self._device).float()

        # Forward pass: the model returns three output heads, each scored
        # against its matching target with its own BCE-with-logits loss.
        output = self.model(data)
        l1 = self.criterion1(output[0], target[0])
        l2 = self.criterion2(output[1], target[1])
        l3 = self.criterion3(output[2], target[2])

loss = (
self.criterion_wts[0] * l1
+ self.criterion_wts[1] * l2
+ self.criterion_wts[2] * l3
)

return loss, output, target

    def training_step(self, batch, batch_idx):
        """
        Lightning training step.
        Args:
            batch (Tuple[torch.Tensor, ...]): (data, target) pair from HungarianDataset; target holds the three per-head targets
        """

train_loss, output, target = self.common_step(batch, batch_idx)

        # Threshold the sigmoid of the first output head at 0.5 to get binary predictions
        preds = torch.sigmoid(output[0]) > 0.5

train_f1 = self.train_f1(preds, target[0])

# Log loss and F1-score
self.log("train_loss", train_loss, on_step=False, on_epoch=True, prog_bar=False)
self.log("train_f1", train_f1, on_step=False, on_epoch=True, prog_bar=False)

return train_loss

    def validation_step(self, batch, batch_idx):
        """
        Lightning validation step.
        Args:
            batch (Tuple[torch.Tensor, ...]): (data, target) pair from HungarianDataset; target holds the three per-head targets
        """

val_loss, output, target = self.common_step(batch, batch_idx)

preds = torch.sigmoid(output[0]) > 0.5
val_f1 = self.val_f1(preds, target[0])

# Log loss and F1-score
self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("val_f1", val_f1, on_step=False, on_epoch=True, prog_bar=True)

return val_loss

    def test_step(self, batch, batch_idx):
        """
        Lightning test step.
        Args:
            batch (Tuple[torch.Tensor, ...]): (data, target) pair from HungarianDataset; target holds the three per-head targets
        """
test_loss, output, target = self.common_step(batch, batch_idx)

preds = torch.sigmoid(output[0]) > 0.5
test_f1 = self.test_f1(preds, target[0])

# Log loss and F1-score
self.log("test_loss", test_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("test_f1", test_f1, on_step=False, on_epoch=True, prog_bar=True)

        return test_loss

    # def on_fit_start(self) -> None:
    #     """
    #     Called at the beginning of the fit loop.
    #     - Checks the consistency of the DataModule's parameters
    #     """
    #     self.check_datamodule_parameter()

    # def on_test_start(self) -> None:
    #     """
    #     Called at the beginning of the testing loop.
    #     - Checks the consistency of the DataModule's parameters
    #     """
    #     self.check_datamodule_parameter()

    def configure_optimizers(self):
        """
        Configure optimizers and schedulers; called automatically by Lightning's Trainer.
        Returns:
            torch.optim.Optimizer
        """
        return self.optimizer

    # def common_logging(
    #     self, stage: str, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    # ) -> None:
    #     """
    #     Common logging for training, validation and test steps.
    #     Args:
    #         stage (str): Stage of the training
    #         outputs (STEP_OUTPUT): Output of the {train,validation,test}_step method
    #         batch (Dict[str, torch.Tensor]): Dict with keys "audio", "phonemes_ids", "phonemes_str"
    #         batch_idx (int): Index of the batch
    #     """
    #     # Log loss
    #     self.log(f"{stage}/loss", outputs["loss"], sync_dist=True)
    #
    #     # Log metrics
    #     predicted_phonemes = self.get_phonemes_from_logits(outputs["logits"])
    #     target_phonemes = batch["phonemes_str"]
    #     metrics_to_log = self.metrics(predicted_phonemes, target_phonemes)
    #     metrics_to_log = {f"{stage}/{k}": v for k, v in metrics_to_log.items()}
    #
    #     self.log_dict(dictionary=metrics_to_log, sync_dist=True, prog_bar=True)


class AttentionLayer(nn.Module):
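
The commit message says train_hnet.py now wires these classes together with TensorBoard logging. A hedged sketch of that wiring (the train_hnet.py diff is not shown here, so arguments, paths, and the learning rate are assumptions):

# Sketch only: mirrors the commit message, not the actual train_hnet.py diff.
from functools import partial

import lightning as L
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from torch import optim

from hungarian_net.dataset import HungarianDataModule
from hungarian_net.models import HNetGRULightning

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HNetGRULightning(
    device=device,
    max_len=2,
    optimizer=partial(optim.Adam, lr=1e-3),  # assumed learning rate
)
dm = HungarianDataModule(
    train_filename="data/hung_data_train",  # assumed path
    test_filename="data/hung_data_test",    # assumed path
)

# The *_step methods log loss and F1 per epoch; TensorBoardLogger writes them out.
trainer = L.Trainer(max_epochs=10, logger=TensorBoardLogger(save_dir="lightning_logs"))
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)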
