Enhance checkpoint and logging configurations; update test setups
- Update configs/callbacks/hnet_checkpoint.yaml to:
  - Save checkpoints under trained_on_${sample_range_trained_on}/standard_resolution/
  - Append timestamp to filenames for better traceability

- Modify configs/logging/tensorboard.yaml to:
  - Change logger name to trained_on_${sample_range_trained_on}/standard_resolution
  - Update version format to hnet_model_DOA$14,360, dropping the sample range suffix

- Update pytest.ini to include new markers:
  - Add scenarios_run for tests that train the algorithm
  - Add scenarios_test for tests that evaluate the algorithm

- Improve documentation in tests/scenarios_tests/conftest.py:
  - Add Args and Returns sections to fixtures for clarity
  - Update descriptions to accurately reflect fixture behavior

- Clean up tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py:
  - Remove unused numpy import

- Update data paths from 20241205 to 20241206 and rename the test function in tests/scenarios_tests/run/test_scenarios_run.py:
  - Rename test_train_hnetgru_under_various_distributions to test_run_under_various_distributions for clarity
MaloOLIVIER committed Dec 6, 2024
1 parent b104e93 commit cae3588
Showing 7 changed files with 33 additions and 24 deletions.
4 changes: 2 additions & 2 deletions configs/callbacks/hnet_checkpoint.yaml
@@ -2,8 +2,8 @@ checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
save_last: True # additionally, always save model from last epoch
verbose: False
dirpath: "${hydra:runtime.cwd}/checkpoints/${now:%Y%m%d}/"
filename: "hnet_model_DOA${max_len}_${sample_range_used}_{epoch:01d}"
dirpath: "${hydra:runtime.cwd}/checkpoints/${now:%Y%m%d}/trained_on_${sample_range_trained_on}/standard_resolution/"
filename: "hnet_model_DOA${max_len}_${sample_range_used}_{epoch:01d}_{now:%H%M%S}"
monitor: "validation_loss"
save_top_k: 1
mode: "min"
4 changes: 2 additions & 2 deletions configs/logging/tensorboard.yaml
@@ -3,7 +3,7 @@ log_every_n_steps: 100
logger:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${hydra:runtime.cwd}/tb_logs/"
-name: 'hnet_model'
-version: "hnet_model_DOA$14,360_${sample_range_used}"
+name: 'trained_on_${sample_range_trained_on}/standard_resolution'
+version: "hnet_model_DOA$14,360"
log_graph: False
default_hp_metric: False
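
TensorBoardLogger writes runs under save_dir/name/version, so the new name groups runs by training sample range, mirroring the checkpoint tree above. A small illustrative sketch (the resolved sample range below is assumed):

```python
# Editor's sketch: how the new name/version map onto the TensorBoard log directory.
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger(
    save_dir="tb_logs",
    name="trained_on_3000-5000-15000/standard_resolution",  # assumed resolved value
    version="hnet_model_DOA$14,360",
)
# events land under tb_logs/trained_on_3000-5000-15000/standard_resolution/hnet_model_DOA$14,360
print(logger.log_dir)
```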
2 changes: 2 additions & 0 deletions pytest.ini
@@ -3,5 +3,7 @@ addopts = --cache-clear --strict-markers -vv --capture=tee-sys --cov=hungarian_n
markers =
consistency: mark tests for consistency checks
scenarios: mark tests for scenario-based tests
+scenarios_run: mark tests for scenario-based tests that train the algorithm
+scenarios_test: mark tests for scenario-based tests that test the algorithm
scenarios_generate_data: mark tests for scenario-based tests that generate data
nonregression: mark non-regression tests to ensure existing functionality is not broken
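
Because pytest runs with --strict-markers, registering the new markers here is what allows tests to use them. A short sketch of the intended usage (the test name and body are illustrative only):

```python
# Editor's sketch: marking a scenario test so it can be selected with
#   pytest -m scenarios_run    (training scenarios)
#   pytest -m scenarios_test   (evaluation scenarios)
import pytest


@pytest.mark.scenarios
@pytest.mark.scenarios_run
def test_example_training_scenario():
    # placeholder body; the real tests parametrize training/test data paths
    assert True
```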
3 changes: 1 addition & 2 deletions run.py
@@ -38,8 +38,7 @@ def main(cfg: DictConfig):
dict(hydra.utils.instantiate(cfg.metrics))
)
lightning_module: L.LightningModule = hydra.utils.instantiate(
-cfg.lightning_module,
-metrics=metrics
+cfg.lightning_module, metrics=metrics
)

# Instantiate Trainer
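The run.py change is a formatting cleanup only; the call still passes metrics through to the LightningModule constructor. For readers unfamiliar with the pattern, hydra.utils.instantiate builds the object named by _target_ and forwards extra keyword arguments, illustrated here with a stand-in target:

```python
# Editor's sketch: extra keyword arguments to instantiate() go to the target's
# constructor, which is how `metrics=metrics` reaches the LightningModule.
# datetime.timedelta is only a stand-in target for the illustration.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "datetime.timedelta"})
delta = instantiate(cfg, hours=1)  # equivalent to datetime.timedelta(hours=1)
print(delta)
```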
9 changes: 8 additions & 1 deletion tests/scenarios_tests/conftest.py
@@ -37,6 +37,13 @@ def batch_size(request) -> int:
and behavior with various batch sizes. This helps in ensuring that the model scales
appropriately with different amounts of data processed in each training iteration.
+Args:
+request (FixtureRequest): Pytest's fixture request object that provides access to the
+parameters specified in the `params` list.
+Returns:
+int: The current value of `batch_size` for the test iteration.
Example:
When used in a test, `batch_size` will sequentially take the values 64, 128, and 256.
"""
@@ -80,7 +80,7 @@ def nb_epochs(request) -> int:
parameters specified in the `params` list.
Returns:
-int: The current `nb_epochs` for the test iteration.
+int: The current number of training epochs for the test iteration.
"""
return request.param

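For orientation, this is roughly what the documented fixture looks like with the new Args/Returns sections in place; only the 64/128/256 values are stated in the diff, the rest is the standard parametrized-fixture pattern:

```python
# Editor's sketch (assumed shape) of the batch_size fixture described above.
import pytest


@pytest.fixture(params=[64, 128, 256])
def batch_size(request) -> int:
    """Provide batch sizes for the scenario tests.

    Args:
        request (FixtureRequest): Pytest's fixture request object that provides
            access to the parameters specified in the `params` list.

    Returns:
        int: The current value of `batch_size` for the test iteration.
    """
    return request.param
```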
1 change: 0 additions & 1 deletion tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py
@@ -1,6 +1,5 @@
# tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py

-import numpy as np
import pytest

from generate_hnet_training_data import main
34 changes: 18 additions & 16 deletions tests/scenarios_tests/run/test_scenarios_run.py
@@ -1,46 +1,48 @@
# tests/scenarios_tests/run/test_scenarios_run.py

import os
-from pathlib import Path
import re
+from pathlib import Path

import pytest


@pytest.mark.scenarios
+@pytest.mark.scenarios_run
@pytest.mark.parametrize(
"training_data, test_data",
[
(
"data/20241205/train/hung_data_train_DOA2_3000-5000-15000",
"data/20241205/test/hung_data_test_DOA2_3000-5000-15000",
"data/20241206/train/hung_data_train_DOA2_3000-5000-15000",
"data/20241206/test/hung_data_test_DOA2_3000-5000-15000",
),
(
"data/20241205/train/hung_data_train_DOA2_5000-5000-5000",
"data/20241205/test/hung_data_test_DOA2_5000-5000-5000",
"data/20241206/train/hung_data_train_DOA2_5000-5000-5000",
"data/20241206/test/hung_data_test_DOA2_5000-5000-5000",
),
(
"data/20241205/train/hung_data_train_DOA2_1000-3000-31000",
"data/20241205/test/hung_data_test_DOA2_1000-3000-31000",
"data/20241206/train/hung_data_train_DOA2_1000-3000-31000",
"data/20241206/test/hung_data_test_DOA2_1000-3000-31000",
),
(
"data/20241205/train/hung_data_train_DOA2_2600-5000-17000",
"data/20241205/test/hung_data_test_DOA2_2600-5000-17000",
"data/20241206/train/hung_data_train_DOA2_2600-5000-17000",
"data/20241206/test/hung_data_test_DOA2_2600-5000-17000",
),
(
"data/20241205/train/hung_data_train_DOA2_6300-4000-1500",
"data/20241205/test/hung_data_test_DOA2_6300-4000-1500",
"data/20241206/train/hung_data_train_DOA2_6300-4000-1500",
"data/20241206/test/hung_data_test_DOA2_6300-4000-1500",
),
(
"data/20241205/train/hung_data_train_DOA2_2000-7000-14000",
"data/20241205/test/hung_data_test_DOA2_2000-7000-14000",
"data/20241206/train/hung_data_train_DOA2_2000-7000-14000",
"data/20241206/test/hung_data_test_DOA2_2000-7000-14000",
),
(
"data/20241205/train/hung_data_train_DOA2_2500-8000-8500",
"data/20241205/test/hung_data_test_DOA2_2500-8000-8500",
"data/20241206/train/hung_data_train_DOA2_2500-8000-8500",
"data/20241206/test/hung_data_test_DOA2_2500-8000-8500",
),
],
)
-def test_train_hnetgru_under_various_distributions(training_data, test_data):
+def test_run_under_various_distributions(training_data, test_data):
"""
Train the HNetGRU model with various data distributions.
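
A hypothetical sketch of how one parametrized case could compose a config for the training run; the config name and override keys are assumptions, not taken from the repository:

```python
# Editor's sketch (hypothetical): build a config for one training/test pair.
# The real test presumably passes the chosen paths to the project's run entry point.
from hydra import compose, initialize


def compose_for_distribution(training_data: str, test_data: str):
    with initialize(config_path="../../../configs", version_base=None):
        cfg = compose(
            config_name="config",  # assumed top-level config name
            overrides=[
                f"training_data={training_data}",  # assumed override key
                f"test_data={test_data}",  # assumed override key
            ],
        )
    return cfg
```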
