started to add Hydra, need to finish
MaloOLIVIER committed Dec 2, 2024
1 parent d2b183b commit 945f05c
Showing 6 changed files with 94 additions and 19 deletions.
9 changes: 9 additions & 0 deletions configs/callbacks/hnet_checkpoint.yaml
@@ -0,0 +1,9 @@
checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  save_last: True # additionally, always save model from last epoch
  verbose: False
  dirpath: "checkpoints/${now:%Y%m%d}/"
  filename: "epoch_{epoch:01d}" # "hnet_model_DOA10,138_{'-'.join(map(str, sample_range_used))}_epoch_{epoch:01d}"
  monitor: "validation_loss"
  save_top_k: 1
  mode: "min"
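For reference, this callback config resolves to the following ModelCheckpoint (a hand-written sketch, not code from the commit; the date directory is what Hydra's ${now:%Y%m%d} resolver would produce on the commit date):

from lightning.pytorch.callbacks import ModelCheckpoint

# Sketch of the instantiated callback; the dirpath date is illustrative.
checkpoint = ModelCheckpoint(
    save_last=True,                   # additionally keep the last epoch's model
    verbose=False,
    dirpath="checkpoints/20241202/",  # ${now:%Y%m%d} resolved at run time
    filename="epoch_{epoch:01d}",     # e.g. epoch_7.ckpt
    monitor="validation_loss",
    save_top_k=1,
    mode="min",
)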
3 changes: 3 additions & 0 deletions configs/callbacks/rich_model_summary.yaml
@@ -0,0 +1,3 @@
model_summary:
  _target_: lightning.pytorch.callbacks.RichModelSummary
  max_depth: 3
9 changes: 9 additions & 0 deletions configs/logging/tensorboard.yaml
@@ -0,0 +1,9 @@
log_every_n_steps: 100

logger:
  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  save_dir: "tb_logs/"
  name: null
  version: .
  log_graph: False
  default_hp_metric: False
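For reference, a hand-written equivalent of the logger node above (a sketch, not code from the commit); version: . makes TensorBoard write event files directly into save_dir rather than into a version_N subfolder:

from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

# Sketch of what hydra.utils.instantiate builds from cfg.logging.logger.
logger = TensorBoardLogger(
    save_dir="tb_logs/",
    name=None,               # no per-experiment name subdirectory
    version=".",             # log straight into save_dir
    log_graph=False,
    default_hp_metric=False,
)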
24 changes: 24 additions & 0 deletions configs/run.yaml
@@ -0,0 +1,24 @@
### Top-level config file for run.py

# Top-level variables available in all config files
description: "${hydra:runtime.choices.lightning_datamodule}: ${lightning_datamodule.id}"

# Hydra configuration
version_base: "1.3"
hydra:
  job:
    chdir: True # change working directory to the job directory
  run:
    dir: "outputs/run/${hydra:runtime.choices.lightning_datamodule}/${lightning_datamodule.id}/${now:%Y-%m-%d_%H-%M-%S}"

# Composing configs
defaults:
  - lightning_datamodule: null # NEEDS TO BE OVERRIDDEN, will also determine the metrics and the run directory
  - lightning_module: null # NEEDS TO BE OVERRIDDEN
  - trainer: ddp
  - callbacks: # Dict of callbacks; see the example command after this file
      # - bwe_checkpoint
      - rich_model_summary
  - logging: tensorboard
  - metrics: ${lightning_datamodule}
  - _self_ # priority is given to run.yaml for overrides
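Because lightning_datamodule and lightning_module default to null, every run must select them on the command line. The group option names below are hypothetical and assume matching files under configs/lightning_datamodule/ and configs/lightning_module/:

python run.py lightning_datamodule=hungarian lightning_module=hnet_gru  # option names hypothetical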
25 changes: 25 additions & 0 deletions configs/trainer/ddp.yaml
@@ -0,0 +1,25 @@
_target_: lightning.Trainer

# Hardware
accelerator: gpu
num_nodes: 1
devices: '0,'
strategy: 'ddp_find_unused_parameters_true'

# Epochs and batch sizes
max_epochs: -1
#accumulate_grad_batches: 2
#limit_train_batches: 20
#limit_val_batches: 10
#limit_test_batches: 0.05
#check_val_every_n_epoch: 15

# Logging
log_every_n_steps: ${logging.log_every_n_steps}
#track_grad_norm: 1

# Training
#overfit_batches: 1
#precision: 64
#weights_summary: "full"
#precision: 16
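Once instantiated, this file roughly amounts to the following Trainer call (a sketch with the commented-out options omitted; log_every_n_steps is interpolated from the logging config):

from lightning import Trainer

# Hand-written equivalent of configs/trainer/ddp.yaml.
trainer = Trainer(
    accelerator="gpu",
    num_nodes=1,
    devices="0,",            # GPU index 0
    strategy="ddp_find_unused_parameters_true",
    max_epochs=-1,           # no epoch cap; stopping is left to callbacks or the user
    log_every_n_steps=100,   # from logging.log_every_n_steps
)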
43 changes: 24 additions & 19 deletions run.py
@@ -14,19 +14,19 @@
 from hungarian_net.lightning_modules.hnet_gru_lightning import HNetGRULightning
 
 
-# @hydra.main(
-#     config_path="configs",
-#     config_name="run.yaml",
-#     version_base="1.3",
-# )
-def main(
-    batch_size=256,
+@hydra.main(
+    config_path="configs",
+    config_name="run.yaml",
+    version_base="1.3",
+)
+def main(cfg: DictConfig):
+    """ batch_size=256,
     nb_epochs=1000,
     max_len=2,
     sample_range_used=[3000, 5000, 15000],
     filename_train="data/reference/hung_data_train",
-    filename_test="data/reference/hung_data_test",
-):
+    filename_test="data/reference/hung_data_test", """
 
 
 
+    # TODO: Rewrite/refactor the code following Julien HAURET's VibraVox template
@@ -55,29 +55,34 @@ def main(
         max_len=max_len,
     )
 
+    # Instantiate LightningModule
+    lightning_module: LightningModule = hydra.utils.instantiate(
+        cfg.lightning_module,
+        metrics=metrics,
+    )
+
     # Get current date
     current_date = datetime.datetime.now().strftime("%Y%m%d")
 
-    os.makedirs(f"models/{current_date}", exist_ok=True)
+    os.makedirs(f"checkpoints/{current_date}", exist_ok=True)
 
     # Human-readable filename
-    dirpath = f"models/{current_date}/"
+    dirpath = f"checkpoints/{current_date}/"
     out_filename = f"hnet_model_DOA10,138_{'-'.join(map(str, sample_range_used))}"
 
-    logger = TensorBoardLogger("tb_logs", name="hungarian_net")
-    checkpoint_callback = ModelCheckpoint(
+    """ checkpoint_callback = ModelCheckpoint(
         dirpath=dirpath,
         filename=out_filename,
         monitor="validation_loss",
         save_top_k=1,
         mode="min",
-    )
+    ) """
 
-    trainer = L.Trainer(
-        max_epochs=nb_epochs,
-        logger=logger,
-        callbacks=[checkpoint_callback],
-        # gpus=1 if use_cuda else 0
+    # Instantiate Trainer
+    callbacks: List[Callback] = list(hydra.utils.instantiate(cfg.callbacks).values())
+    logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
     )
 
     trainer.fit(lightning_module, datamodule=lightning_datamodule)
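Note on _convert_="partial": it makes hydra.utils.instantiate pass plain Python containers (here the callbacks list) to the Trainer instead of OmegaConf objects, while structured configs are passed through unconverted.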
