Skip to content

Commit

Permalink
**refactor(run.py): Integrate Hydra for module instantiation and clea…
Browse files Browse the repository at this point in the history
…n up configurations**

- **Use Hydra to instantiate  and :**
  - Replaced manual instantiation with  and .

- **Remove hard-coded dataset paths:**
  - Eliminated hard-coded  and  to rely on Hydra configurations.

- **Update imports:**
  - Removed unused imports such as , , , , , and .
  - Added necessary imports like  from ,  from , and  from .

- **Clean up Trainer instantiation:**
  - Removed the  and related directory creation logic.
  - Updated  instantiation to exclude  and ensure proper callback management.

- **Enhance documentation:**
  - Improved docstrings for better clarity and understanding of the main function's purpose.

- **Miscellaneous:**
  - Added type hints for better code readability and maintenance.

**Benefits:**
- Enhances configurability and flexibility by leveraging Hydra's powerful configuration management.
- Simplifies the  script by removing hard-coded values and unused components.
- Improves code maintainability and readability through better documentation and type hinting.
  • Loading branch information
MaloOLIVIER committed Dec 3, 2024
1 parent 30da96d commit 1e75822
Showing 1 changed file with 27 additions and 63 deletions.
90 changes: 27 additions & 63 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,50 @@
import os
import random
import warnings
from typing import List

import hydra
import lightning as L
import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from hungarian_net.lightning_datamodules.hungarian_datamodule import HungarianDataModule, HungarianDataset
from hungarian_net.lightning_modules.hnet_gru_lightning import HNetGRULightning


@hydra.main(
config_path="configs",
config_name="run.yaml",
version_base="1.3",
)
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from torchmetrics import MetricCollection

@hydra.main(
config_path="configs",
config_name="run.yaml",
version_base="1.3",
)
def main(cfg: DictConfig):
""" batch_size=256,
nb_epochs=1000,
max_len=2,
sample_range_used=[3000, 5000, 15000],
filename_train="data/reference/hung_data_train",
filename_test="data/reference/hung_data_test", """
"""
Instantiate all necessary modules, train and test the model.

Args:
cfg (DictConfig): Hydra configuration object, passed in by the @hydra.main decorator
"""

# TODO: Réécriture/factorisation du code sur le modèle de VibraVox de Julien HAURET
# TODO: leverager TensorBoard, Hydra, Pytorch Lightning, RayTune, Docker

# Temporarly mock the dataloader
filename_train = "data/20241202/train/hung_data_train_DOA2_3000-5000-15000"
filename_test = "data/20241202/test/hung_data_test_DOA2_3000-5000-15000"

lightning_datamodule = HungarianDataModule(
train_filename=filename_train,
test_filename=filename_test,
max_len=max_len,
batch_size=batch_size,
num_workers=4,
)

# metrics: MetricCollection = MetricCollection(
# dict(hydra.utils.instantiate(cfg.metrics))
# )

use_cuda = torch.cuda.is_available()
lightning_module = HNetGRULightning(
metrics=None,
device=torch.device("cuda" if use_cuda else "cpu"),
max_len=max_len,

# Instantiate LightningDataModule
lightning_datamodule: L.LightningDataModule = hydra.utils.instantiate(
cfg.lightning_datamodule
)

# Instantiate LightningModule
lightning_module: LightningModule = hydra.utils.instantiate(
metrics: MetricCollection = MetricCollection(
dict(hydra.utils.instantiate(cfg.metrics))
)
lightning_module: L.LightningModule = hydra.utils.instantiate(
cfg.lightning_module,
metrics=metrics,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), #mock for now #TODO: hide device, supposed to be handled by lightning
)

# Get current date
current_date = datetime.datetime.now().strftime("%Y%m%d")

os.makedirs(f"checkpoints/{current_date}", exist_ok=True)

# Human-readable filename
dirpath = f"checkpoints/{current_date}/"
out_filename = f"hnet_model_DOA{max_len}_{'-'.join(map(str, sample_range_used))}"

""" checkpoint_callback = ModelCheckpoint(
dirpath=dirpath,
filename=out_filename,
monitor="validation_loss",
save_top_k=1,
mode="min",
) """

# Instantiate Trainer
callbacks: List[Callback] = list(hydra.utils.instantiate(cfg.callbacks).values())
callbacks: List[L.Callback] = list(hydra.utils.instantiate(cfg.callbacks).values())
logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
trainer: Trainer = hydra.utils.instantiate(
trainer: L.Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

Expand Down

0 comments on commit 1e75822

Please sign in to comment.