Fixed typed iterator tests
AKuederle committed May 23, 2024
1 parent b49ed60 commit ee93e22
Showing 5 changed files with 53 additions and 70 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- ``custom_hash``, the internally used hashing method based on pickle, is now part of the public API via ``tpcp.misc``.

### Changed

- Relatively large rework of the ``TypedIterator``. We recommend rereading the example.
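
  A minimal sketch of the reworked access pattern, mirroring the updated tests in this commit (the dataclass, field names, and import path are illustrative assumptions, not part of the changelog itself):

      from dataclasses import make_dataclass

      from tpcp.misc import TypedIterator

      rt = make_dataclass("ResultType", ["result_1"])
      iterator = TypedIterator(rt)
      for d, r in iterator.iterate([1, 2, 3]):
          r.result_1 = d * 2

      # Per-field results now live on the ``results_`` dataclass ...
      assert iterator.results_.result_1 == [2, 4, 6]
      # ... and each raw result carries the input alongside the produced result object.
      assert [r.input for r in iterator.raw_results_] == [1, 2, 3]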

## [0.32.0] - 2024-04-17

- The snapshot plugin now supports a new command line argument `--snapshot-only-check` that will fail the test if no
4 changes: 2 additions & 2 deletions examples/recipies/_04_typed_iterator.py
@@ -167,7 +167,7 @@ class QRSResultType:
# Note that we can type these functions using the `TypedIteratorResultTuple` type.
# Like the iterator itself, this type is generic and allows you to specify the input and output types.
# So in our case, the input is `ECGExampleData` and the output is `QRSResultType`.
from typing_extensions import TypeAlias, reveal_type
from typing_extensions import TypeAlias

from tpcp.misc import TypedIteratorResultTuple
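
# As a rough sketch (the alias and the aggregation function below are illustrative
# assumptions, not code from this commit), an aggregation over the collected results
# could be typed like this:
QRSIterResultTuple: TypeAlias = TypedIteratorResultTuple[ECGExampleData, QRSResultType]


def total_n_r_peaks(results: list[QRSIterResultTuple]) -> int:
    # Sum the number of detected R-peaks over all iterated datapoints.
    # (A real aggregation should also handle entries whose result was never set.)
    return sum(len(r.result.r_peak_positions) for r in results)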

@@ -297,4 +297,4 @@ class SimpleResultType:

custom_iterator.results_
# %%
custom_iterator.additional_results_
custom_iterator.additional_results_
12 changes: 5 additions & 7 deletions tests/test_examples/test_all_examples.py
@@ -4,7 +4,6 @@
import pandas as pd
import pytest
from numpy.testing import assert_almost_equal, assert_array_equal
from pandas.testing import assert_frame_equal

matplotlib.use("Agg")

@@ -165,12 +164,11 @@ def test_caching_example():


def test_typed_iterator_example():
from examples.recipies._04_typed_iterator import custom_iterator, iterator
from examples.recipies._04_typed_iterator import custom_iterator, qrs_iterator

assert len(iterator.r_peak_positions_) == 17782
assert sum(iterator.n_r_peaks_.values()) == 17782
assert len(iterator.raw_results_) == 12
assert len(qrs_iterator.results_.r_peak_positions) == 17782
assert sum(qrs_iterator.results_.n_r_peaks.values()) == 17782
assert len(qrs_iterator.raw_results_) == 12

assert len(custom_iterator.n_samples_) == 2
assert len(custom_iterator.results_.n_samples) == 2
assert len(custom_iterator.raw_results_) == 2
assert_frame_equal(custom_iterator.inputs_[0], pd.DataFrame({"data": [1, 2, 3, 4, 5]}))
78 changes: 32 additions & 46 deletions tests/test_misc/test_typed_iterator.py
@@ -28,18 +28,24 @@ def test_simple_no_agg():
check_data.append(d)
check_results.append(r)

assert check_data == data == iterator.inputs_
assert check_results == iterator.raw_results_
assert iterator.result_1_ == [0, 1, 2]
assert iterator.result_2_ == [0, 2, 4]
assert iterator.result_3_ == [0, 3, 6]
inputs = [r.input for r in iterator.raw_results_]
assert check_data == data == inputs
raw_output = [r.result for r in iterator.raw_results_]
assert check_results == raw_output
assert iterator.results_.result_1 == [0, 1, 2]
assert iterator.results_.result_2 == [0, 2, 4]
assert iterator.results_.result_3 == [0, 3, 6]


def test_simple_with_agg():
rt = make_dataclass("ResultType", ["result_1", "result_2", "result_3"])

iterator = TypedIterator[rt](
rt, aggregations=[("result_1", lambda i, r: sum(i)), ("result_2", lambda i, r: sum(r))]
iterator = TypedIterator(
rt,
aggregations=[
("result_1", lambda re: sum(r.input for r in re)),
("result_2", lambda re: sum(r.result.result_2 for r in re)),
],
)

data = [1, 2, 3]
@@ -51,9 +57,9 @@ def test_simple_with_agg():
result_obj = iterator.results_

assert isinstance(result_obj, rt)
assert iterator.result_1_ == result_obj.result_1 == 6
assert iterator.result_2_ == result_obj.result_2 == 12
assert iterator.result_3_ == result_obj.result_3 == [3, 6, 9]
assert iterator.results_.result_1 == 6
assert iterator.results_.result_2 == 12
assert iterator.results_.result_3 == [3, 6, 9]


def test_warning_incomplete_iterate():
@@ -63,61 +69,41 @@ def test_warning_incomplete_iterate():

next(iterator.iterate(data))
with pytest.warns(UserWarning):
partial_results = iterator.raw_results_
partial_results = [r.result for r in iterator.raw_results_]

assert partial_results == [
rt(result_1=TypedIterator.NULL_VALUE, result_2=TypedIterator.NULL_VALUE, result_3=TypedIterator.NULL_VALUE)
]

with pytest.warns(UserWarning):
partial_results = iterator.result_1_
partial_results = iterator.results_.result_1

assert partial_results == [TypedIterator.NULL_VALUE]


def test_invalid_attr_error():
field_names = ["result_1", "result_2", "result_3"]
rt = make_dataclass("ResultType", field_names)
iterator = TypedIterator(rt)
data = [1, 2, 3]

[next(iterator.iterate(data)) for _ in range(3)]

with pytest.raises(AttributeError) as e:
iterator.invalid_attr_

assert "invalid_attr_" in str(e.value)
for f in field_names:
assert f"{f}_" in str(e.value)


def test_not_allowed_attr_error():
field_names = ["results"]

rt = make_dataclass("ResultType", field_names)
iterator = TypedIterator(rt)
data = [1, 2, 3]

with pytest.raises(ValueError):
[next(iterator.iterate(data)) for _ in range(3)]


def test_agg_with_empty():
def test_additional_aggregations():
rt = make_dataclass("ResultType", ["result_1", "result_2", "result_3"])

iterator = TypedIterator[rt](
rt, aggregations=[("result_1", lambda i, r: sum(i)), ("result_2", lambda i, r: sum(r))]
iterator = TypedIterator(
rt,
aggregations=[
("result_1", lambda re: sum(r.input for r in re)),
("result_2", lambda re: sum(r.result.result_2 for r in re)),
("additional_results", lambda re: sum(r.result.result_2 + 1 for r in re)),
],
)

data = [1, 2, 3]
for i, r in iterator.iterate(data):
r.result_1 = i - 1
# We Don't set result 2 -> it will remain an empty value and should skip agg
r.result_2 = i * 2
r.result_3 = i * 3

result_obj = iterator.results_

assert isinstance(result_obj, rt)
assert iterator.result_1_ == result_obj.result_1 == 6
assert iterator.result_2_ == [iterator.NULL_VALUE, iterator.NULL_VALUE, iterator.NULL_VALUE]
assert iterator.result_3_ == result_obj.result_3 == [3, 6, 9]
assert iterator.results_.result_1 == 6
assert iterator.results_.result_2 == 12
assert iterator.results_.result_3 == [3, 6, 9]

assert iterator.additional_results_["additional_results"] == iterator.results_.result_2 + len(data)
25 changes: 10 additions & 15 deletions tpcp/misc/_typed_iterator.py
@@ -48,6 +48,7 @@ class BaseTypedIterator(Algorithm, Generic[InputTypeT, DataclassT]):
An optional list of aggregations to apply to the results.
This has the form ``[(result_name, aggregation_function), ...]``.
Each aggregation function gets ``raw_results_`` provided as input and can return an arbitrary object.
Note that the aggregation function needs to handle the case where no result was set for a specific attribute.
If a result name is in the list, the aggregation will be applied to it when accessing ``results_``
(i.e. ``results_.{result_name}``).
If no aggregation is defined for a result, a simple list of all results will be returned.
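For illustration (the dataclass and field names below are hypothetical and only mirror the tests), summing a single result field across all iterations could be configured like this::

    TypedIterator(
        ResultType,
        aggregations=[
            ("result_2", lambda results: sum(r.result.result_2 for r in results)),
        ],
    )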
@@ -114,6 +115,7 @@ class BaseTypedIterator(Algorithm, Generic[InputTypeT, DataclassT]):
_result_fields: set[str]
# We use this as cache
_results: DataclassT
_additional_results: dict[str, Any]

done_: dict[str, bool]

@@ -165,18 +167,9 @@ def _iterate(
# Reset all caches
if hasattr(self, "_results"):
del self._results
if hasattr(self, "_additional_results"):
del self._additional_results
self.done_ = {}

result_field_names = {f.name for f in fields(self.data_type)}
not_allowed_fields = {"results", "raw_results", "done", "inputs"}
if not_allowed_fields.intersection(result_field_names):
raise ValueError(
f"The result dataclass cannot have a field called {not_allowed_fields}. "
"These fields are used by the TypedIterator to store the results. "
"Having these fields in the result object will result in naming conflicts."
)

self._result_fields = result_field_names
self._raw_results = []

self.done_[iteration_name] = False
@@ -223,14 +216,15 @@ def _agg_result(self, raw_results: list[IteratorResult]) -> tuple[dict[str, Any]
values = self._get_default_agg(name)(raw_results)
agg_results[name] = values
# If there are further aggregations, we apply them as well
additional_aggregations = {name: agg(raw_results) for name, agg in aggregations.items() if name not in agg_results}
additional_aggregations = {
name: agg(raw_results) for name, agg in aggregations.items() if name not in agg_results
}
return agg_results, additional_aggregations


def _cache_agg(self) -> None:
agg_results, additional_aggregations = self._agg_result(self.raw_results_)
self._results = self.data_type(**agg_results)
self._additional_aggregations = additional_aggregations
self._additional_results = additional_aggregations

@property
def results_(self) -> DataclassT:
Expand All @@ -244,11 +238,12 @@ def results_(self) -> DataclassT:
if not hasattr(self, "_results"):
self._cache_agg()
return self._results

@property
def additional_results_(self) -> dict[str, Any]:
if not hasattr(self, "_additional_results"):
self._cache_agg()
return self._additional_aggregations
return self._additional_results

@classmethod
def filter_iterator_results(