Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 26 additions & 41 deletions streamz/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import deque
from datetime import timedelta
import copy
import functools
import logging
import six
Expand Down Expand Up @@ -660,76 +661,60 @@ class accumulate(Stream):
This performs running or cumulative reductions, applying the function
to the previous total and the new element. The function should take
two arguments, the previous accumulated state and the next element and
it should return a new accumulated state,
- ``state = func(previous_state, new_value)`` (returns_state=False)
- ``state, result = func(previous_state, new_value)`` (returns_state=True)

where the new_state is passed to the next invocation. The state or result
is emitted downstream for the two cases.
it should return a new accumulated state.

Parameters
----------
func: callable
start: object
Initial value, passed as the value of ``previous_state`` on the first
invocation. Defaults to the first submitted element
returns_state: boolean
Initial value. Defaults to the first submitted element
returns_state: boolean, optional
If true then func should return both the state and the value to emit
If false then both values are the same, and func returns one value
reset_stream : Stream instance or None, optional
If not None, when the ``reset_stream`` stream emits the accumulate
node's state will revert to the initial state (set by the ``start``
kwarg), defaults to None
**kwargs:
Keyword arguments to pass to func

Examples
--------
A running total, producing triangular numbers

>>> source = Stream()
>>> source.accumulate(lambda acc, x: acc + x).sink(print)
>>> for i in range(5):
... source.emit(i)
0
1
3
6
10

A count of number of events (including the current one)

>>> source = Stream()
>>> source.accumulate(lambda acc, x: acc + 1, start=0).sink(print)
>>> for _ in range(5):
... source.emit(0)
1
2
3
4
5

Like the builtin "enumerate".

>>> source = Stream()
>>> source.accumulate(lambda acc, x: ((acc[0] + 1, x), (acc[0], x)),
... start=(0, 0), returns_state=True
... ).sink(print)
>>> for i in range(3):
... source.emit(0)
(0, 0)
(1, 0)
(2, 0)
"""
_graphviz_shape = 'box'

def __init__(self, upstream, func, start=no_default, returns_state=False,
**kwargs):
_graphviz_shape = "box"

def __init__(
self, upstream, func, start=no_default, returns_state=False,
reset_stream=None, **kwargs
):
self.func = func
self.kwargs = kwargs
self.state = start
# XXX: maybe don't need deepcopy?
self._initial_state = copy.deepcopy(start)
self.returns_state = returns_state
self.reset_node = reset_stream
# this is one of a few stream specific kwargs
stream_name = kwargs.pop('stream_name', None)
Stream.__init__(self, upstream, stream_name=stream_name)
stream_name = kwargs.pop("stream_name", None)
if reset_stream:
Stream.__init__(self, upstreams=[upstream, reset_stream],
stream_name=stream_name)
else:
Stream.__init__(self, upstream, stream_name=stream_name)

def update(self, x, who=None):
if who is self.reset_node:
self.state = copy.deepcopy(self._initial_state)
return
if self.state is no_default:
self.state = x
return self._emit(x)
Expand Down
35 changes: 35 additions & 0 deletions streamz/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,5 +1164,40 @@ def start(self):
assert flag == [True]


def test_accumulate():
a = Stream()
b = a.accumulate(lambda x, y: x + y)
L = b.sink_to_list()
LL = []

for i in range(10):
a.emit(i)
if len(LL) == 0:
LL.append(i)
else:
LL.append(i + LL[-1])

assert L == LL


def test_accumulate_reset():
a = Stream()
rn = Stream()
b = a.accumulate(lambda x, y: x + y, reset_stream=rn)
L = b.sink_to_list()
LL = []

for i in range(10):
if i == 5:
rn.emit('hi')
a.emit(i)
if len(LL) == 0 or i == 5:
LL.append(i)
else:
LL.append(i + LL[-1])

assert L == LL


if sys.version_info >= (3, 5):
from streamz.tests.py3_test_core import * # noqa