Skip to content

Commit 3ba10cd

Browse files
committed
add executor.map() functionality
1 parent 4e0ef19 commit 3ba10cd

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

src/qasync/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,25 @@ def submit(self, callback, *args, **kwargs):
195195
return future
196196

197197
def map(self, func, *iterables, timeout=None):
198-
raise NotImplementedError("use as_completed on the event loop")
198+
deadline = time.monotonic() + timeout if timeout is not None else None
199+
futures = [self.submit(func, *args) for args in zip(*iterables)]
200+
201+
# must have generator as a closure so that the submit occurs before first iteration
202+
def generator():
203+
try:
204+
futures.reverse()
205+
while futures:
206+
if deadline is not None:
207+
yield _result_or_cancel(
208+
futures.pop(), timeout=deadline - time.monotonic()
209+
)
210+
else:
211+
yield _result_or_cancel(futures.pop())
212+
finally:
213+
for future in futures:
214+
future.cancel()
215+
216+
return generator()
199217

200218
def shutdown(self, wait=True, *, cancel_futures=False):
201219
with self.__shutdown_lock:
@@ -222,6 +240,16 @@ def __exit__(self, *args):
222240
self.shutdown()
223241

224242

243+
def _result_or_cancel(fut, timeout=None):
244+
try:
245+
try:
246+
return fut.result(timeout)
247+
finally:
248+
fut.cancel()
249+
finally:
250+
del fut # break reference cycle in exceptions
251+
252+
225253
def _format_handle(handle: asyncio.Handle):
226254
cb = getattr(handle, "_callback", None)
227255
if isinstance(getattr(cb, "__self__", None), asyncio.tasks.Task):

tests/test_qthreadexec.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import threading
77
import time
88
import weakref
9-
from concurrent.futures import CancelledError
9+
from concurrent.futures import CancelledError, TimeoutError
10+
from itertools import islice
1011

1112
import pytest
1213

@@ -145,3 +146,103 @@ def task():
145146
assert cancels > 0
146147
else:
147148
assert cancels == 0
149+
150+
151+
def test_map(executor):
152+
"""Basic test of executor map functionality"""
153+
results = list(executor.map(lambda x: x + 1, range(10)))
154+
assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
155+
156+
results = list(executor.map(lambda x, y: x + y, range(10), range(9)))
157+
assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16]
158+
159+
160+
def test_map_timeout(executor):
161+
"""Test that map with timeout raises TimeoutError and cancels futures"""
162+
results = []
163+
164+
def func(x):
165+
nonlocal results
166+
time.sleep(0.05)
167+
results.append(x)
168+
return x
169+
170+
start = time.monotonic()
171+
with pytest.raises(TimeoutError):
172+
list(executor.map(func, range(10), timeout=0.01))
173+
duration = time.monotonic() - start
174+
# this test is flaky on some platforms, so we give it a wide bearth.
175+
assert duration < 0.1
176+
177+
executor.shutdown(wait=True)
178+
# only about half of the tasks should have completed
179+
# because the max number of workers is 5 and the rest of
180+
# the tasks were not started at the time of the cancel.
181+
assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
182+
183+
184+
def test_map_error(executor):
185+
"""Test that map with an exception will raise, and remaining tasks are cancelled"""
186+
results = []
187+
188+
def func(x):
189+
nonlocal results
190+
time.sleep(0.05)
191+
if len(results) == 5:
192+
raise ValueError("Test error")
193+
results.append(x)
194+
return x
195+
196+
with pytest.raises(ValueError):
197+
list(executor.map(func, range(15)))
198+
199+
executor.shutdown(wait=True, cancel_futures=False)
200+
assert len(results) <= 10, "Final 5 at least should have been cancelled"
201+
202+
203+
@pytest.mark.parametrize("cancel", [True, False])
204+
def test_map_shutdown(executor, cancel):
205+
results = []
206+
207+
def func(x):
208+
nonlocal results
209+
time.sleep(0.05)
210+
results.append(x)
211+
return x
212+
213+
# Get the first few results.
214+
# Keep the iterator alive so that it isn't closed when its reference is dropped.
215+
m = executor.map(func, range(15))
216+
values = list(islice(m, 5))
217+
assert values == [0, 1, 2, 3, 4]
218+
219+
executor.shutdown(wait=True, cancel_futures=cancel)
220+
if cancel:
221+
assert len(results) < 15, "Some tasks should have been cancelled"
222+
else:
223+
assert len(results) == 15, "All tasks should have been completed"
224+
m.close()
225+
226+
227+
def test_map_start(executor):
228+
"""Test that map starts tasks immediately, before iterating"""
229+
e = threading.Event()
230+
m = executor.map(lambda x: (e.set(), x), range(1))
231+
e.wait(timeout=0.1)
232+
assert list(m) == [(None, 0)]
233+
234+
235+
def test_map_close(executor):
236+
"""Test that closing a running map cancels all remaining tasks."""
237+
results = []
238+
def func(x):
239+
nonlocal results
240+
time.sleep(0.05)
241+
results.append(x)
242+
return x
243+
m = executor.map(func, range(10))
244+
# must start the generator so that close() has any effect
245+
assert next(m) == 0
246+
m.close()
247+
executor.shutdown(wait=True, cancel_futures=False)
248+
assert len(results) < 10, "Some tasks should have been cancelled"

0 commit comments

Comments
 (0)