diff --git a/docs/api.rst b/docs/api.rst index cc03375..f997097 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -61,6 +61,7 @@ Selection of the desired metric result from multiple poses. OracleSelector TopSelector RandomSelector + DeviationAggregator Analysis functions ------------------ diff --git a/src/peppr/selector.py b/src/peppr/selector.py index c9016d0..f5d8782 100644 --- a/src/peppr/selector.py +++ b/src/peppr/selector.py @@ -5,6 +5,7 @@ "OracleSelector", "TopSelector", "RandomSelector", + "DeviationAggregator", ] from abc import ABC, abstractmethod @@ -174,3 +175,17 @@ def select(self, values: np.ndarray, smaller_is_better: bool) -> float: return np.nanmin(top_values) else: return np.nanmax(top_values) + + +class DeviationAggregator(Selector): + """ + Aggregator that computes the standard deviation of the values. This can be + used to assess the consistency of accuracy of a set of predicted poses. + """ + + @property + def name(self) -> str: + return "stdev" + + def select(self, values: np.ndarray, smaller_is_better: bool) -> float: + return np.nanstd(values) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index bfa32c1..555b617 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -72,3 +72,17 @@ def test_random_selector(): ] assert np.isclose(np.mean(selected_values), 9, rtol=0.5) + + +def test_deviation_aggregator(): + """ + This test verifies that the DeviationAggregator returns the expected + value of standard deviation for a given set of values. + """ + selector = peppr.DeviationAggregator() + values = np.linspace(0, 10, 10 + 1) + expected_std = np.std(values) + + selected_value = selector.select(values, smaller_is_better=False) + + assert np.isclose(selected_value, expected_std)