Skip to content

Commit

Permalink
reshaped tests directory
Browse files Browse the repository at this point in the history
  • Loading branch information
MaloOLIVIER committed Dec 3, 2024
1 parent 82811e1 commit 5420644
Show file tree
Hide file tree
Showing 22 changed files with 90 additions and 164 deletions.
3 changes: 2 additions & 1 deletion tests/consistency_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# tests/consistency_tests/conftest.py

import pytest

from hungarian_net.torch_modules.attention_layer import AttentionLayer
from hungarian_net.torch_modules.hnet_gru import HNetGRU
from hungarian_net.torch_modules.hnet_gru import HNetGRU

# TODO: maybe rewrite docstrings

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# tests/consistency_tests/generate_hnet_training_data/test_generate_hnet_training_data.py
# tests/consistency_tests/generate_hnet_training_data/tests_consistency_generate_hnet_training_data.py

import numpy as np
import pytest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/consistency_tests/lightning_datamodules/tests_consistency_hungarian_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tests/consistency_tests/lightning_modules/tests_consistency_hnet_gru_lightning.py

import pytest
3 changes: 3 additions & 0 deletions tests/consistency_tests/run/tests_consistency_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tests/consistency_tests/run/tests_consistency_run.py


Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/consistency_tests/torch_modules/tests_consistency_attention_layer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
# tests/consistency_tests/model/test_train_hnet.py
# tests/consistency_tests/torch_modules/tests_consistency_hnet_gru.py

import pytest
import torch

from hungarian_net.torch_modules.attention_layer import AttentionLayer
from hungarian_net.torch_modules.hnet_gru import HNetGRU

# TODO: maybe rewrite docstrings
from hungarian_net.torch_modules.hnet_gru import HNetGRU


@pytest.mark.consistency
def test_model_initialization(model, max_doas) -> None:
def test_HNetGRU_init(model, max_doas) -> None:
"""Test the initialization of the HNetGRU model.
Args:
Expand All @@ -25,9 +21,8 @@ def test_model_initialization(model, max_doas) -> None:
model.max_len == max_doas
), f"Expected max_doas {max_doas}, got {model.max_len}"


@pytest.mark.consistency
def test_forward_pass(model, batch_size) -> None:
def test_HNetGRU_forward(model, batch_size) -> None:
"""Test the forward pass of the HNetGRU model to ensure correct output shapes.
Args:
Expand Down Expand Up @@ -55,22 +50,4 @@ def test_forward_pass(model, batch_size) -> None:
assert output3.shape == (
batch_size,
model.max_len,
), f"Expected output3 shape {(batch_size, model.max_len)}, got {output3.shape}"


@pytest.mark.consistency
def test_attention_layer_initialization(attentionLayer) -> None:
"""Test the initialization of the AttentionLayer.
Args:
attentionLayer (AttentionLayer): The AttentionLayer instance provided by the fixture.
Returns:
None
"""
assert isinstance(
attentionLayer, AttentionLayer
), f"AttentionLayer is not an instance of AttentionLayer class, got {attentionLayer.__repr__()}"


# TODO: write test for compute_weight_accuracy
), f"Expected output3 shape {(batch_size, model.max_len)}, got {output3.shape}"
3 changes: 2 additions & 1 deletion tests/nonregression_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# tests/nonregression_tests/conftest.py

import pytest
from hungarian_net.torch_modules.hnet_gru import HNetGRU

from hungarian_net.torch_modules.hnet_gru import HNetGRU


@pytest.fixture(params=[256])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# tests/nonregression_tests/generate_hnet_training_data/test_nonregression_generate_hnet_training_data.py
# tests/nonregression_tests/generate_hnet_training_data/tests_nonregression_generate_hnet_training_data.py
import pytest
from pytest_mock import mocker

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/nonregression_tests/lightning_datamodules/tests_nonregression_hungarian_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tests/nonregression_tests/lightning_modules/tests_nonregression_hnet_gru_lightning.py

import pytest
120 changes: 0 additions & 120 deletions tests/nonregression_tests/model/test_nonregression_train_hnet.py

This file was deleted.

20 changes: 20 additions & 0 deletions tests/nonregression_tests/run/tests_nonregression_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# tests/nonregression_tests/run/tests_nonregression_run.py

import os

import pytest
import torch
from pytest_mock import mocker

from hungarian_net.torch_modules.hnet_gru import HNetGRU
from run import main as train_main
from run import set_seed

# TODO: Performing a non-regression test by directly comparing a newly trained model with a reference model is ineffective due to inherent numerical computation errors that can cause discrepancies.
# TODO: In future iterations, it would be more effective to assess regression by evaluating the model's individual components (e.g., functions, classes, methods) to ensure each part operates as expected without being affected by numerical inaccuracies.


@pytest.mark.nonregression
def test_non_regression_train_hnet(mocker):

set_seed()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/nonregression_tests/torch_modules/tests_nonregression_attention_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/nonregression_tests/torch_modules/tests_nonregression_hnet_gru.py
3 changes: 2 additions & 1 deletion tests/scenarios_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np
import pytest
from hungarian_net.torch_modules.hnet_gru import HNetGRU

from hungarian_net.torch_modules.hnet_gru import HNetGRU


@pytest.fixture(params=[2])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# tests/scenarios_tests/test_generate_hnet_training_data.py
# tests/scenarios_tests/generate_hnet_training_data/tests_scenarios_generate_hnet_training_data.py

import numpy as np
import pytest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/scenarios_tests/lightning_datamodules/tests_scenarios_hungarian_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tests/scenarios_tests/lightning_modules/tests_scenarios_hnet_gru_lightning.py

import pytest
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# tests/scenarios_tests/model/test_train_hnet.py
# tests/scenarios_tests/run/tests_scenarios_run.py

import re

import pytest

from run import main


Expand Down Expand Up @@ -40,7 +42,7 @@
],
)
def test_train_model_under_various_distributions(
max_doas, batch_size, nb_epochs, training_data, test_data
# max_doas, batch_size, nb_epochs, training_data, test_data
):
"""
Train the HNetGRU model with various data distributions.
Expand All @@ -53,6 +55,9 @@ def test_train_model_under_various_distributions(
test_data (str): Path to the testing data file.
"""

# TODO: to re-work
# TODO: next step : train hnet model under various data distributions

# Extract sample ranges from the training_data filename
match = re.search(r"hung_data_train_DOA\d+_(\d+)-(\d+)-(\d+)", training_data)
if match:
Expand All @@ -63,11 +68,11 @@ def test_train_model_under_various_distributions(
# Mock nb_epochs to be 1 regardless of the input
nb_epochs = 1

main(
batch_size=batch_size,
nb_epochs=nb_epochs,
max_len=max_doas,
sample_range_used=sample_range_used,
filename_train=training_data,
filename_test=test_data,
)
# main(
# batch_size=batch_size,
# nb_epochs=nb_epochs,
# max_len=max_doas,
# sample_range_used=sample_range_used,
# filename_train=training_data,
# filename_test=test_data,
# )
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# tests/scenarios_tests/torch_modules/tests_scenarios_attention_layer.py

import pytest

from hungarian_net.torch_modules.attention_layer import AttentionLayer


@pytest.mark.consistency
def test_AttentionLayer_init(attentionLayer) -> None:
"""Test the initialization of the AttentionLayer.
Args:
attentionLayer (AttentionLayer): The AttentionLayer instance provided by the fixture.
Returns:
None
"""
assert isinstance(
attentionLayer, AttentionLayer
), f"AttentionLayer is not an instance of AttentionLayer class, got {attentionLayer.__repr__()}"


Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# tests/scenarios_tests/torch_modules/tests_scenarios_hnet_gru.py

0 comments on commit 5420644

Please sign in to comment.