Skip to content

Commit

Permalink
reorganise test folder
Browse files Browse the repository at this point in the history
  • Loading branch information
MarineHoche committed Sep 29, 2023
1 parent 4d37ff4 commit 510297a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 256 deletions.
168 changes: 0 additions & 168 deletions tests/conftest.py

This file was deleted.

7 changes: 0 additions & 7 deletions tests/fairness_check/conftest.py

This file was deleted.

77 changes: 68 additions & 9 deletions tests/fairness_check/test_analysis_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from pathlib import Path

import numpy as np
import numpy.testing as npt
import pytest

from famews.fairness_check.metrics import (
get_corrected_npv,
get_corrected_precision,
get_npv,
get_precision,
)
from famews.scripts.run_fairness_analysis import main as fairness_main
from famews.fairness_check.metrics import (get_corrected_npv,
get_corrected_precision, get_npv,
get_precision)
from famews.fairness_check.utils.compute_threshold_score import \
GetThresholdScore
from famews.pipeline import PipelineState


def test_prevalence_correction_binary_metrics():
Expand All @@ -20,7 +18,9 @@ def test_prevalence_correction_binary_metrics():
prev_baseline_npv = 0.1
prev = 0.3
precision = get_precision(y_true, y_pred)
corrected_precision = get_corrected_precision(y_true, y_pred, prev_baseline_precision, prev)
corrected_precision = get_corrected_precision(
y_true, y_pred, prev_baseline_precision, prev
)
assert (
corrected_precision >= precision
), f"Prevalence correction lead to a smaller precision: {corrected_precision} vs {precision}"
Expand All @@ -29,3 +29,62 @@ def test_prevalence_correction_binary_metrics():
assert (
corrected_npv >= npv
), f"Prevalence correction lead to a smaller NPV: {corrected_npv} vs {npv}"


@pytest.mark.parametrize(
    "name_threshold",
    ["recall_0.25", "precision_0.40", "npv_0.625", "fpr_0.42", "event_recall_0.5"],
)
def test_compute_threshold(name_threshold: str):
    """Check GetThresholdScore picks a sensible threshold for each target metric.

    Builds a small PipelineState with two patients' prediction/label arrays
    (NaN labels mark timesteps excluded from scoring), runs the
    GetThresholdScore stage, and asserts the resulting ``state.threshold``
    exceeds a hand-derived lower bound for the metric encoded in
    ``name_threshold`` (format: ``"<metric>_<target value>"``).
    """
    state = PipelineState()
    state.name_threshold = name_threshold
    # Per-patient tuples of (predicted scores, ground-truth labels);
    # np.nan labels are timesteps that must be ignored by the metrics.
    state.predictions = {
        1: (
            np.array([0.6, 0.1, 0.1, 0, 0, 0.3, 0.3]),
            np.array([0, 1, 1, np.nan, np.nan, 0, 0]),
        ),
        2: (
            np.array([0.2, 0.15, 0.25, 0.4, 0.6, 0, 0, 0, 0, 0.45, 0.45]),
            np.array(
                [
                    0,
                    0,
                    0,
                    1,
                    1,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    0,
                    0,
                ]
            ),
        ),
    }
    # Event windows (start, stop) per patient, used by event-based recall.
    state.event_bounds = {1: [(3, 5)], 2: [(5, 9)]}
    state.horizon = 2 * 5 / 60  # prediction horizon in hours (2 timesteps of 5 min)
    state.max_len = 20
    state.timestep = 5  # minutes per timestep
    get_threshold_stage = GetThresholdScore(state)
    get_threshold_stage.run()
    if get_threshold_stage.metric_name == "recall":
        assert (
            state.threshold > 0.4
        ), f"Threshold for target recall at 0.25 is wrong, expected 0.4 but got {state.threshold}"
    elif get_threshold_stage.metric_name == "precision":
        assert (
            state.threshold > 0.3
        ), f"Threshold for target precision at 0.40 is wrong, expected 0.3 but got {state.threshold}"
    elif get_threshold_stage.metric_name == "npv":
        assert (
            state.threshold > 0.4
        ), f"Threshold for target NPV at 0.625 is wrong, expected 0.4 but got {state.threshold}"
    elif get_threshold_stage.metric_name == "fpr":
        # Fixed message: this branch checks the FPR target, not recall.
        assert (
            state.threshold > 0.3
        ), f"Threshold for target FPR at 0.42 is wrong, expected 0.3 but got {state.threshold}"
    elif get_threshold_stage.metric_name == "event_recall":
        assert (
            state.threshold > 0.1
        ), f"Threshold for target event-based recall at 0.5 is wrong, expected 0.1 but got {state.threshold}"
72 changes: 0 additions & 72 deletions tests/fairness_check/test_setup.py

This file was deleted.

0 comments on commit 510297a

Please sign in to comment.