diff --git a/streamz/core.py b/streamz/core.py index 00d9f4db..92948cb0 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -2,6 +2,7 @@ from collections import deque from datetime import timedelta +import copy import functools import logging import six @@ -660,76 +661,60 @@ class accumulate(Stream): This performs running or cumulative reductions, applying the function to the previous total and the new element. The function should take two arguments, the previous accumulated state and the next element and - it should return a new accumulated state, - - ``state = func(previous_state, new_value)`` (returns_state=False) - - ``state, result = func(previous_state, new_value)`` (returns_state=True) - - where the new_state is passed to the next invocation. The state or result - is emitted downstream for the two cases. + it should return a new accumulated state. Parameters ---------- func: callable start: object - Initial value, passed as the value of ``previous_state`` on the first - invocation. Defaults to the first submitted element - returns_state: boolean + Initial value. Defaults to the first submitted element + returns_state: boolean, optional If true then func should return both the state and the value to emit If false then both values are the same, and func returns one value + reset_stream : Stream instance or None, optional + If not None, when the ``reset_stream`` stream emits the accumulate + node's state will revert to the initial state (set by the ``start`` + kwarg), defaults to None **kwargs: Keyword arguments to pass to func Examples -------- - A running total, producing triangular numbers - >>> source = Stream() >>> source.accumulate(lambda acc, x: acc + x).sink(print) >>> for i in range(5): ... source.emit(i) - 0 1 3 6 10 - - A count of number of events (including the current one) - - >>> source = Stream() - >>> source.accumulate(lambda acc, x: acc + 1, start=0).sink(print) - >>> for _ in range(5): - ... source.emit(0) - 1 - 2 - 3 - 4 - 5 - - Like the builtin "enumerate". - - >>> source = Stream() - >>> source.accumulate(lambda acc, x: ((acc[0] + 1, x), (acc[0], x)), - ... start=(0, 0), returns_state=True - ... ).sink(print) - >>> for i in range(3): - ... source.emit(0) - (0, 0) - (1, 0) - (2, 0) """ - _graphviz_shape = 'box' - def __init__(self, upstream, func, start=no_default, returns_state=False, - **kwargs): + _graphviz_shape = "box" + + def __init__( + self, upstream, func, start=no_default, returns_state=False, + reset_stream=None, **kwargs + ): self.func = func self.kwargs = kwargs self.state = start + # XXX: maybe don't need deepcopy? + self._initial_state = copy.deepcopy(start) self.returns_state = returns_state + self.reset_node = reset_stream # this is one of a few stream specific kwargs - stream_name = kwargs.pop('stream_name', None) - Stream.__init__(self, upstream, stream_name=stream_name) + stream_name = kwargs.pop("stream_name", None) + if reset_stream: + Stream.__init__(self, upstreams=[upstream, reset_stream], + stream_name=stream_name) + else: + Stream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): + if who is self.reset_node: + self.state = copy.deepcopy(self._initial_state) + return if self.state is no_default: self.state = x return self._emit(x) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 602c8be2..8c12b0a8 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -1164,5 +1164,40 @@ def start(self): assert flag == [True] +def test_accumulate(): + a = Stream() + b = a.accumulate(lambda x, y: x + y) + L = b.sink_to_list() + LL = [] + + for i in range(10): + a.emit(i) + if len(LL) == 0: + LL.append(i) + else: + LL.append(i + LL[-1]) + + assert L == LL + + +def test_accumulate_reset(): + a = Stream() + rn = Stream() + b = a.accumulate(lambda x, y: x + y, reset_stream=rn) + L = b.sink_to_list() + LL = [] + + for i in range(10): + if i == 5: + rn.emit('hi') + a.emit(i) + if len(LL) == 0 or i == 5: + LL.append(i) + else: + LL.append(i + LL[-1]) + + assert L == LL + + if sys.version_info >= (3, 5): from streamz.tests.py3_test_core import * # noqa