Skip to content

Consistency of accumulate initial parameter #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ._core import (
ScopedIter,
awaitify as _awaitify,
Sentinel,
borrow as _borrow,
)
from .builtins import (
Expand Down Expand Up @@ -64,9 +63,6 @@ async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]:
yield item


__ACCUMULATE_SENTINEL = Sentinel("<no default>")


async def add(x: ADD, y: ADD) -> ADD:
"""The default reduction of :py:func:`~.accumulate`"""
return x + y
Expand All @@ -78,7 +74,7 @@ async def accumulate(
Callable[[Any, Any], Any], Callable[[Any, Any], Awaitable[Any]]
] = add,
*,
initial: Any = __ACCUMULATE_SENTINEL,
initial: Any = None,
) -> AsyncIterator[Any]:
"""
An :term:`asynchronous iterator` on the running reduction of ``iterable``
Expand All @@ -105,11 +101,7 @@ async def accumulate(iterable, function, *, initial):
"""
async with ScopedIter(iterable) as item_iter:
try:
value = (
initial
if initial is not __ACCUMULATE_SENTINEL
else await anext(item_iter)
)
value = initial if initial is not None else await anext(item_iter)
except StopAsyncIteration:
raise TypeError(
"accumulate() of empty sequence with no initial value"
Expand Down
6 changes: 5 additions & 1 deletion asyncstdlib/itertools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ from ._typing import AnyIterable, ADD, T, T1, T2, T3, T4, T5

def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: ...
@overload
def accumulate(iterable: AnyIterable[ADD]) -> AsyncIterator[ADD]: ...
def accumulate(
iterable: AnyIterable[ADD], *, initial: None = ...
) -> AsyncIterator[ADD]: ...
@overload
def accumulate(iterable: AnyIterable[ADD], *, initial: ADD) -> AsyncIterator[ADD]: ...
@overload
def accumulate(
iterable: AnyIterable[T],
function: Callable[[T, T], T] | Callable[[T, T], Awaitable[T]],
*,
initial: None = ...,
) -> AsyncIterator[T]: ...
@overload
def accumulate(
Expand Down
4 changes: 4 additions & 0 deletions docs/source/api/itertools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ Iterator transforming
.. autofunction:: accumulate(iterable: (async) iter T, function: (T, T) → (await) T = add [, initial: T])
:async-for: :T

.. versionchanged:: 3.13.2

``initial=None`` means no initial value is assumed.

.. autofunction:: starmap(function: (*A) → (await) T, iterable: (async) iter (A, ...))
:async-for: :T

Expand Down
12 changes: 12 additions & 0 deletions unittests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def reduction(x, y):

@sync
async def test_accumulate_default():
"""Test the default function of accumulate"""
for itertype in (asyncify, list):
assert await a.list(a.accumulate(itertype([0, 1]))) == list(
itertools.accumulate([0, 1])
Expand All @@ -53,10 +54,21 @@ async def test_accumulate_default():

@sync
async def test_accumulate_misuse():
"""Test wrong arguments to accumulate"""
with pytest.raises(TypeError):
assert await a.list(a.accumulate([]))


@sync
async def test_accumulate_initial():
"""Test the `initial` argument to accumulate"""
assert (
await a.list(a.accumulate(asyncify([1, 2, 3]), initial=None))
== await a.list(a.accumulate(asyncify([1, 2, 3])))
== list(itertools.accumulate([1, 2, 3], initial=None))
)


batched_cases = [
(range(10), 2, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]),
(range(10), 3, [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]),
Expand Down