Skip to content

Commit 451ac38

Browse files
authored
[#23] reworking exception chaining (#28)
* Changed exception handling * now exceptions are chained (before they were added in `args`) * timeout errors are now chained (before they were not included at all) * in case of dogpiling, all callers are now notified about the error (see issue #23)
1 parent eaf4ac8 commit 451ac38

9 files changed

+127
-31
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
2.0.0
2+
-----
3+
4+
* Changed exception handling
5+
* now exceptions are chained (before they were added in `args`)
6+
* timeout errors are now chained (before they were not included at all)
7+
* in case of dogpiling, all callers are now notified about the error (see issue #23)
8+
19
1.2.2
210
-----
311

memoize/statuses.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import datetime
55
import logging
66
from asyncio import Future
7-
from typing import Optional, Dict, Awaitable
7+
from typing import Optional, Dict, Awaitable, Union
88

99
from memoize import coerced
1010
from memoize.entry import CacheKey, CacheEntry
@@ -14,8 +14,7 @@ class UpdateStatuses:
1414
def __init__(self, update_lock_timeout: datetime.timedelta = datetime.timedelta(minutes=5)) -> None:
1515
self.logger = logging.getLogger(__name__)
1616
self._update_lock_timeout = update_lock_timeout
17-
# type declaration should not be in comment once we drop py35 support
18-
self._updates_in_progress = {} # type: Dict[CacheKey, Future]
17+
self._updates_in_progress: Dict[CacheKey, Future] = {}
1918

2019
def is_being_updated(self, key: CacheKey) -> bool:
2120
"""Checks if update for given key is in progress. Obtained info is valid until control gets back to IO-loop."""
@@ -49,17 +48,19 @@ def mark_updated(self, key: CacheKey, entry: CacheEntry) -> None:
4948
update = self._updates_in_progress.pop(key)
5049
update.set_result(entry)
5150

52-
def mark_update_aborted(self, key: CacheKey) -> None:
51+
def mark_update_aborted(self, key: CacheKey, exception: Exception) -> None:
5352
"""Informs that update failed to complete.
54-
Calls to 'is_being_updated' will return False until 'mark_being_updated' will be called."""
53+
Calls to 'is_being_updated' will return False until 'mark_being_updated' will be called.
54+
Accepts exception to propagate it across all clients awaiting an update."""
5555
if key not in self._updates_in_progress:
5656
raise ValueError('Key {} is not being updated'.format(key))
5757
update = self._updates_in_progress.pop(key)
58-
update.set_result(None)
58+
update.set_result(exception)
5959

60-
def await_updated(self, key: CacheKey) -> Awaitable[Optional[CacheEntry]]:
60+
def await_updated(self, key: CacheKey) -> Awaitable[Union[CacheEntry, Exception]]:
6161
"""Waits (asynchronously) until update in progress has benn finished.
62-
Returns updated entry or None if update failed/timed-out.
62+
Returns awaitable with the updated entry
63+
(or awaitable with an exception if update failed/timed-out).
6364
Should be called only if 'is_being_updated' returned True (and since then IO-loop has not been lost)."""
6465
if not self.is_being_updated(key):
6566
raise ValueError('Key {} is not being updated'.format(key))

memoize/wrapper.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ async def refresh(actual_entry: Optional[CacheEntry], key: CacheKey,
7777
if actual_entry is None and update_statuses.is_being_updated(key):
7878
logger.debug('As entry expired, waiting for results of concurrent refresh %s', key)
7979
entry = await update_statuses.await_updated(key)
80-
if entry is None:
81-
raise CachedMethodFailedException('Concurrent refresh failed to complete')
80+
if isinstance(entry, Exception):
81+
raise CachedMethodFailedException('Concurrent refresh failed to complete') from entry
8282
return entry
8383
elif actual_entry is not None and update_statuses.is_being_updated(key):
8484
logger.debug('As update point reached but concurrent update already in progress, '
@@ -103,12 +103,12 @@ async def refresh(actual_entry: Optional[CacheEntry], key: CacheKey,
103103
return offered_entry
104104
except (asyncio.TimeoutError, _timeout_error_type()) as e:
105105
logger.debug('Timeout for %s: %s', key, e)
106-
update_statuses.mark_update_aborted(key)
107-
raise CachedMethodFailedException('Refresh timed out')
106+
update_statuses.mark_update_aborted(key, e)
107+
raise CachedMethodFailedException('Refresh timed out') from e
108108
except Exception as e:
109109
logger.debug('Error while refreshing cache for %s: %s', key, e)
110-
update_statuses.mark_update_aborted(key)
111-
raise CachedMethodFailedException('Refresh failed to complete', e)
110+
update_statuses.mark_update_aborted(key, e)
111+
raise CachedMethodFailedException('Refresh failed to complete') from e
112112

113113
@functools.wraps(method)
114114
async def wrapper(*args, **kwargs):

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def prepare_description():
1111

1212
setup(
1313
name='py-memoize',
14-
version='1.2.2',
14+
version='2.0.0',
1515
author='Michal Zmuda',
1616
author_email='[email protected]',
1717
url='https://github.com/DreamLab/memoize',

tests/asynciotests/test_showcase.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,5 @@ async def get_value_or_throw(arg, kwarg=None):
9494
self.assertEqual('ok #2', res4) # value from cache - still relevant
9595
self.assertEqual('ok #2', res5) # stale from cache - refresh in background
9696
self.assertEqual('ok #2', res6) # stale from cache - should be updated but method throws
97-
expected = CachedMethodFailedException('Refresh failed to complete', ValueError('throws #4', ))
98-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
97+
self.assertEqual(str(context.exception), str(CachedMethodFailedException('Refresh failed to complete')))
98+
self.assertEqual(str(context.exception.__cause__), str(ValueError("throws #4")))

tests/asynciotests/test_wrapper_manually_applied_on_asyncio.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from memoize.coerced import _timeout_error_type
12
from tests.py310workaround import fix_python_3_10_compatibility
23

34
fix_python_3_10_compatibility()
@@ -274,8 +275,8 @@ async def get_value(arg, kwarg=None):
274275
await get_value_cached('test1', kwarg='args1')
275276

276277
# then
277-
expected = CachedMethodFailedException('Refresh failed to complete', ValueError('Get lost', ))
278-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
278+
self.assertEqual(str(context.exception), str(CachedMethodFailedException('Refresh failed to complete')))
279+
self.assertEqual(str(context.exception.__cause__), str(ValueError("Get lost")))
279280

280281
@gen_test
281282
async def test_should_throw_exception_on_refresh_timeout(self):
@@ -295,8 +296,8 @@ async def get_value(arg, kwarg=None):
295296
await get_value_cached('test1', kwarg='args1')
296297

297298
# then
298-
expected = CachedMethodFailedException('Refresh timed out')
299-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
299+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
300+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
300301

301302
@staticmethod
302303
async def _call_thrice(call):

tests/asynciotests/test_wrapper_on_asyncio.py

+68-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from memoize.coerced import _timeout_error_type
12
from tests.py310workaround import fix_python_3_10_compatibility
23

34
fix_python_3_10_compatibility()
@@ -160,6 +161,69 @@ async def get_value(arg, kwarg=None):
160161
self.assertEqual(0, res1)
161162
self.assertEqual(1, res2)
162163

164+
@gen_test
165+
async def test_should_return_exception_for_all_concurrent_callers(self):
166+
# given
167+
value = 0
168+
169+
@memoize()
170+
async def get_value(arg, kwarg=None):
171+
raise ValueError(f'stub{value}')
172+
173+
# when
174+
res1 = get_value('test', kwarg='args1')
175+
res2 = get_value('test', kwarg='args1')
176+
res3 = get_value('test', kwarg='args1')
177+
178+
# then
179+
with self.assertRaises(Exception) as context:
180+
await res1
181+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
182+
self.assertEqual(str(context.exception.__cause__), str(ValueError('stub0')))
183+
184+
with self.assertRaises(Exception) as context:
185+
await res2
186+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
187+
self.assertEqual(str(context.exception.__cause__), str(ValueError('stub0')))
188+
189+
with self.assertRaises(Exception) as context:
190+
await res3
191+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
192+
self.assertEqual(str(context.exception.__cause__), str(ValueError('stub0')))
193+
194+
@gen_test
195+
async def test_should_return_timeout_for_all_concurrent_callers(self):
196+
# given
197+
value = 0
198+
199+
@memoize(configuration=DefaultInMemoryCacheConfiguration(method_timeout=timedelta(milliseconds=1)))
200+
async def get_value(arg, kwarg=None):
201+
await _ensure_asyncio_background_tasks_finished()
202+
time.sleep(.200)
203+
await _ensure_asyncio_background_tasks_finished()
204+
return value
205+
206+
# when
207+
res1 = get_value('test', kwarg='args1')
208+
res2 = get_value('test', kwarg='args1')
209+
res3 = get_value('test', kwarg='args1')
210+
211+
# then
212+
with self.assertRaises(Exception) as context:
213+
await res1
214+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
215+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
216+
217+
with self.assertRaises(Exception) as context:
218+
await res2
219+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
220+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
221+
222+
with self.assertRaises(Exception) as context:
223+
await res3
224+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
225+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
226+
163227
@gen_test
164228
async def test_should_return_same_value_on_constant_key_function(self):
165229
# given
@@ -253,8 +317,8 @@ async def get_value(arg, kwarg=None):
253317
await get_value('test1', kwarg='args1')
254318

255319
# then
256-
expected = CachedMethodFailedException('Refresh failed to complete', ValueError('Get lost', ))
257-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
320+
self.assertEqual(str(context.exception), str(CachedMethodFailedException('Refresh failed to complete')))
321+
self.assertEqual(str(context.exception.__cause__), str(ValueError("Get lost")))
258322

259323
@gen_test
260324
async def test_should_throw_exception_on_refresh_timeout(self):
@@ -272,8 +336,8 @@ async def get_value(arg, kwarg=None):
272336
await get_value('test1', kwarg='args1')
273337

274338
# then
275-
expected = CachedMethodFailedException('Refresh timed out')
276-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
339+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
340+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
277341

278342
@staticmethod
279343
async def _call_thrice(call):

tests/tornadotests/test_wrapper.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from memoize.coerced import _timeout_error_type
12
from tests.py310workaround import fix_python_3_10_compatibility
23

34
fix_python_3_10_compatibility()
@@ -265,8 +266,8 @@ def get_value(arg, kwarg=None):
265266
yield get_value('test1', kwarg='args1')
266267

267268
# then
268-
expected = CachedMethodFailedException('Refresh failed to complete', ValueError('Get lost', ))
269-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
269+
self.assertEqual(str(context.exception), str(CachedMethodFailedException('Refresh failed to complete')))
270+
self.assertEqual(str(context.exception.__cause__), str(ValueError("Get lost")))
270271

271272
@gen_test
272273
def test_should_throw_exception_on_refresh_timeout(self):
@@ -285,5 +286,6 @@ def get_value(arg, kwarg=None):
285286
yield get_value('test1', kwarg='args1')
286287

287288
# then
288-
expected = CachedMethodFailedException('Refresh timed out')
289-
self.assertEqual(str(expected), str(context.exception)) # ToDo: consider better comparision
289+
self.assertEqual(context.exception.__class__, CachedMethodFailedException)
290+
self.assertEqual(context.exception.__cause__.__class__, _timeout_error_type())
291+

tests/unit/test_statuses.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def test_should_be_mark_as_updated(self):
6262
def test_should_raise_exception_during_mark_update_as_aborted(self):
6363
# given/when/then
6464
with self.assertRaises(ValueError):
65-
self.update_statuses.mark_update_aborted('key')
65+
self.update_statuses.mark_update_aborted('key', Exception('stub'))
6666

6767
def test_should_mark_update_as_aborted(self):
6868
# given
6969
self.update_statuses.mark_being_updated('key')
7070

7171
# when
72-
self.update_statuses.mark_update_aborted('key')
72+
self.update_statuses.mark_update_aborted('key', Exception('stub'))
7373

7474
# then
7575
self.assertFalse(self.update_statuses.is_being_updated('key'))
@@ -105,3 +105,23 @@ async def test_should_await_updated_return_entry(self):
105105
# then
106106
self.assertIsNone(result)
107107
self.assertFalse(self.update_statuses.is_being_updated('key'))
108+
109+
@gen_test
110+
async def test_concurrent_callers_should_all_get_exception_on_aborted_update(self):
111+
# given
112+
self.update_statuses.mark_being_updated('key')
113+
114+
# when
115+
result1 = self.update_statuses.await_updated('key')
116+
result2 = self.update_statuses.await_updated('key')
117+
result3 = self.update_statuses.await_updated('key')
118+
self.update_statuses.mark_update_aborted('key', ValueError('stub'))
119+
result1 = await result1
120+
result2 = await result2
121+
result3 = await result3
122+
123+
# then
124+
self.assertFalse(self.update_statuses.is_being_updated('key'))
125+
self.assertEqual(str(result1), str(ValueError('stub')))
126+
self.assertEqual(str(result2), str(ValueError('stub')))
127+
self.assertEqual(str(result3), str(ValueError('stub')))

0 commit comments

Comments
 (0)