diff --git a/src/pysorteddict/sorted_dict_view_type.cc b/src/pysorteddict/sorted_dict_view_type.cc index 1d8228c1..6aeeac21 100644 --- a/src/pysorteddict/sorted_dict_view_type.cc +++ b/src/pysorteddict/sorted_dict_view_type.cc @@ -57,7 +57,11 @@ void SortedDictViewIterType::track(RevIterType it) } if (it != this->sd->map->rend()) { - ++it->second.known_referrers; + FwdIterType it_base = it.base(); + if (it_base != this->sd->map->end()) + { + ++it_base->second.known_referrers; + } } else { @@ -68,7 +72,7 @@ void SortedDictViewIterType::track(RevIterType it) } /** - * Do all the necessary bookkeeping required to stop tracking the given + * Do all the necessary bookkeeping required to stop tracking the given forward * iterator of the underlying sorted dictionary. * * The caller should ensure that this method is called immediately after the @@ -76,12 +80,31 @@ void SortedDictViewIterType::track(RevIterType it) * * @param it Previous value of the iterator member. */ -template -void SortedDictViewIterType::untrack(T it) +template<> +void SortedDictViewIterType::untrack(FwdIterType it) { --it->second.known_referrers; } +/** + * Do all the necessary bookkeeping required to stop tracking the given reverse + * iterator of the underlying sorted dictionary. + * + * The caller should ensure that this method is called immediately after the + * iterator member is updated. + * + * @param it Previous value of the iterator member. + */ +template<> +void SortedDictViewIterType::untrack(RevIterType it) +{ + FwdIterType it_base = it.base(); + if (it_base != this->sd->map->end()) + { + --it_base->second.known_referrers; + } +} + template void SortedDictViewIterType::Delete(PyObject* self) { diff --git a/tests/functional/test_keys_iter.py b/tests/functional/test_keys_iter.py index c1ab4938..9c6411f2 100644 --- a/tests/functional/test_keys_iter.py +++ b/tests/functional/test_keys_iter.py @@ -73,3 +73,18 @@ def test_destructive_forward_iteration(sorted_dict): del sorted_dict[key] assert len(sorted_dict) == 0 assert not [*sorted_dict] + + +@pytest.mark.parametrize("sorted_dict", [*range(10), 100, 1_000, 10_000, 100_000], indirect=True) +def test_destructive_reverse_iteration(sorted_dict): + prev_key = None + for key in reversed(sorted_dict): + # A quirk of the implementation of reverse iterators: the current key + # cannot be deleted. + if prev_key is not None: + del sorted_dict[prev_key] + prev_key = key + if prev_key is not None: + del sorted_dict[prev_key] + assert len(sorted_dict) == 0 + assert not [*sorted_dict]