Commit

Ensure pytest and hydra-core are included correctly in requirements.txt | removed docs from run.py
MaloOLIVIER committed Dec 2, 2024
1 parent d6672cd commit 8e42266
Showing 2 changed files with 5 additions and 81 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
@@ -5,4 +5,6 @@ scipy>=1.14.1
torch>=2.5.1
torchaudio>=2.5.1
tensorboard>=2.18.0
-lightning>=2.4.0
+lightning>=2.4.0
+pytest>=8.3.4
+hydra-core>=1.3.2
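
The pins in this hunk are all plain `name>=X.Y.Z` lower bounds, so they can be checked against an environment with a few lines of standard-library Python. The following is only a sketch under that assumption: `parse_min_versions` and `meets_minimum` are hypothetical helpers that are not part of this repository, and they handle purely numeric dotted versions only (a local suffix such as `+cu121` would need a real version parser, for example the `packaging` library).

```python
def parse_min_versions(requirements_text: str) -> dict[str, tuple[int, ...]]:
    """Map each package name to its minimum version as a tuple of ints."""
    minimums = {}
    for raw_line in requirements_text.splitlines():
        line = raw_line.split("#", 1)[0].strip()  # drop comments and whitespace
        if not line or ">=" not in line:
            continue  # this sketch only understands `name>=version` lines
        name, version = line.split(">=", 1)
        minimums[name.strip().lower()] = tuple(
            int(part) for part in version.strip().split(".")
        )
    return minimums


def meets_minimum(installed: str, minimum: tuple[int, ...]) -> bool:
    """Compare a numeric dotted version string against a minimum tuple."""
    return tuple(int(part) for part in installed.split(".")) >= minimum


# The lines visible in this hunk of requirements.txt.
REQUIREMENTS = """\
torch>=2.5.1
torchaudio>=2.5.1
tensorboard>=2.18.0
lightning>=2.4.0
pytest>=8.3.4
hydra-core>=1.3.2
"""

minimums = parse_min_versions(REQUIREMENTS)
print(meets_minimum("8.3.4", minimums["pytest"]))      # exactly the pinned minimum
print(meets_minimum("1.3.1", minimums["hydra-core"]))  # one patch release too old
```

In practice the installed version string would come from `importlib.metadata.version("pytest")` rather than a literal.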
82 changes: 2 additions & 80 deletions run.py
@@ -7,11 +7,8 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from hungarian_net.lightning_datamodules.hungarian_datamodule import HungarianDataModule, HungarianDataset
from hungarian_net.lightning_modules.hnet_gru_lightning import HNetGRULightning
@@ -30,82 +27,7 @@ def main(
     filename_train="data/reference/hung_data_train",
     filename_test="data/reference/hung_data_test",
 ):
-    """
-    Train the Hungarian Network (HNetGRU) model.
-
-    This function orchestrates the training process of the HNetGRU model, including data loading,
-    model initialization, training loop with validation, and saving the best-performing model.
-
-    Args:
-        batch_size (int, optional):
-            Number of samples per training batch. Defaults to 256.
-        nb_epochs (int, optional):
-            Total number of training epochs. Defaults to 1000.
-        max_len (int, optional):
-            Maximum number of Directions of Arrival (DOAs) the model can handle. Defaults to 2.
-        sample_range_used (List[int], optional):
-            List specifying the range of samples used for training. Defaults to [3000, 5000, 15000].
-        filename_train (str, optional):
-            Path to the training data file. Defaults to "data/reference/hung_data_train".
-        filename_test (str, optional):
-            Path to the testing data file. Defaults to "data/reference/hung_data_test".
-
-    Steps:
-        1. **Set Random Seed**:
-            - Ensures reproducibility by setting seeds for Python's `random`, NumPy, and PyTorch.
-        2. **Device Configuration**:
-            - Checks for GPU availability and sets the computation device accordingly (CUDA or CPU).
-        3. **Data Loading**:
-            - Loads the training dataset using `HungarianDataset` with specified parameters.
-            - Initializes a `DataLoader` for batching and shuffling training data.
-            - Calculates class imbalance to handle potential data skew.
-            - Loads the validation dataset similarly.
-        4. **Model and Optimizer Initialization**:
-            - Instantiates the `HNetGRU` model and moves it to the configured device.
-            - Sets up the optimizer (`Adam`) for training the model parameters.
-        5. **Loss Function Definition**:
-            - Defines three separate Binary Cross-Entropy with Logits Loss functions (`criterion1`, `criterion2`, `criterion3`).
-            - Assigns equal weights to each loss component.
-        6. **Training Loop**:
-            - Iterates over the specified number of epochs.
-            - **Training Phase**:
-                a. Sets the model to training mode.
-                b. Iterates over training batches:
-                    - Performs forward pass.
-                    - Computes individual losses.
-                    - Aggregates losses with defined weights.
-                    - Backpropagates and updates model weights.
-                c. Accumulates and averages training losses.
-            - **Validation Phase**:
-                a. Sets the model to evaluation mode.
-                b. Iterates over validation batches without gradient computation:
-                    - Performs forward pass.
-                    - Computes losses.
-                    - Calculates F1 scores with weighted averaging.
-                c. Accumulates and averages validation losses and F1 scores.
-            - **Early Stopping**:
-                - Monitors validation F1 score to identify and save the best-performing model.
-                - Saves model weights with a timestamped filename for version tracking.
-            - **Metrics Logging**:
-                - Prints comprehensive metrics after each epoch, including losses, F1 scores, and accuracy.
-                - Tracks unweighted F1 scores separately for detailed analysis.
-        7. **Final Output**:
-            - Prints the best epoch and corresponding F1 score.
-            - Returns the best-performing model instance.
-
-    Returns:
-        HNetGRU:
-            The trained HNetGRU model with the best validation F1 score.
-    """


# TODO: Réécriture/factorisation du code sur le modèle de VibraVox de Julien HAURET
# TODO: leverager TensorBoard, Hydra, Pytorch Lightning, RayTune, Docker
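
Step 5 of the docstring removed in this hunk describes three Binary Cross-Entropy with Logits losses combined with equal weights. The arithmetic can be illustrated without PyTorch. This is a sketch only: `bce_with_logits` and `combined_loss` are hypothetical names, the weight value `1.0` is an assumption (the docstring says only that the weights are equal), and the mean reduction mirrors the default of `torch.nn.BCEWithLogitsLoss`.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Stable form: max(x, 0) - x*y + log(1 + exp(-|x|)).
        # It never exponentiates a large positive number, so it cannot overflow.
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def combined_loss(outputs, targets, weights=(1.0, 1.0, 1.0)):
    """Weighted sum of one BCE-with-logits term per output head.

    `outputs` and `targets` are sequences of three lists (one per head);
    the equal default weights are an assumption for illustration.
    """
    return sum(
        w * bce_with_logits(o, t) for w, o, t in zip(weights, outputs, targets)
    )


# A logit of 0 corresponds to probability 0.5, so each head contributes ln(2).
heads = [[0.0], [0.0], [0.0]]
labels = [[1.0], [0.0], [1.0]]
print(combined_loss(heads, labels))
```

With equal weights the total is three times the per-head loss, so no single head dominates the gradient unless its own loss is larger.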
Expand Down Expand Up @@ -201,5 +123,5 @@ def setup_environment():


 if __name__ == "__main__":
-    device = setup_environment()
+    setup_environment()
     main()
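
After this change `setup_environment()` is called for its side effects only, and the removed docstring names seeding Python's `random`, NumPy, and PyTorch as the first step of a run. The body of `setup_environment()` is not shown in this diff, so the following is only a sketch of what such seeding could look like. `seed_available_rngs` is a hypothetical helper, the default seed `42` is an assumption, and NumPy and PyTorch are seeded only when they are importable.

```python
import importlib.util
import random


def seed_available_rngs(seed: int = 42) -> int:
    """Seed every random number generator that exists in this interpreter."""
    random.seed(seed)
    if importlib.util.find_spec("numpy") is not None:
        import numpy as np

        np.random.seed(seed)  # legacy global NumPy generator
    if importlib.util.find_spec("torch") is not None:
        import torch

        torch.manual_seed(seed)  # seeds the default PyTorch generators
    return seed


# Re-seeding with the same value reproduces the same draws.
seed_available_rngs(1234)
first_draw = [random.random() for _ in range(3)]
seed_available_rngs(1234)
second_draw = [random.random() for _ in range(3)]
print(first_draw == second_draw)
```

Since the project already depends on Lightning, `lightning.pytorch.seed_everything` is an existing alternative that covers the same three libraries in one call.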
