Skip to content
43 changes: 37 additions & 6 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class _CacheParameters(TypedDict):
class _CacheItem(Generic[_R]):
fut: "asyncio.Future[_R]"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking further on this... if the _CacheItem now holds a reference to the task itself, does it even need a separate Future for the result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented the refactorings I mentioned as possible here #708

later_call: Optional[asyncio.Handle]
waiters: int
task: "asyncio.Task[_R]"

def cancel(self) -> None:
if self.later_call is not None:
Expand Down Expand Up @@ -192,6 +194,15 @@ def _task_done_callback(

fut.set_result(task.result())

def _handle_cancelled_error(self, key: Hashable, cache_item: "_CacheItem") -> None:
# Called when a waiter is cancelled.
# If this is the last waiter and the underlying task is not done,
# cancel the underlying task and remove the cache entry.
if cache_item.waiters == 1 and not cache_item.task.done():
cache_item.cancel() # Cancel TTL expiration
cache_item.task.cancel() # Cancel the running coroutine
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we either need to pop from __tasks here or just remove __tasks entirely. I'd prefer the latter but would like a second opinion before i implement

self.__cache.pop(key, None) # Remove from cache

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")
Expand All @@ -205,8 +216,20 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
if cache_item is not None:
self._cache_hit(key)
if not cache_item.fut.done():
return await asyncio.shield(cache_item.fut)

# Each logical waiter increments waiters on entry.
cache_item.waiters += 1

try:
# All waiters await the same future.
return await asyncio.shield(cache_item.fut)
except asyncio.CancelledError:
# If a waiter is cancelled, handle possible last-waiter cleanup.
self._handle_cancelled_error(key, cache_item)
raise
finally:
# Each logical waiter decrements waiters on exit (normal or cancelled).
cache_item.waiters -= 1
# If the future is already done, just return the result.
return cache_item.fut.result()

fut = loop.create_future()
Expand All @@ -215,14 +238,22 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

self.__cache[key] = _CacheItem(fut, None)
cache_item = _CacheItem(fut, None, 1, task)
self.__cache[key] = cache_item

if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
dropped_key, cache_item = self.__cache.popitem(last=False)
cache_item.cancel()
dropped_key, dropped_cache_item = self.__cache.popitem(last=False)
dropped_cache_item.cancel()

self._cache_miss(key)
return await asyncio.shield(fut)

try:
return await asyncio.shield(fut)
except asyncio.CancelledError:
self._handle_cancelled_error(key, cache_item)
raise
finally:
cache_item.waiters -= 1

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
Expand Down
58 changes: 58 additions & 0 deletions tests/test_cancel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import asyncio

import pytest

from async_lru import alru_cache


@pytest.mark.parametrize("num_to_cancel", [0, 1, 2, 3])
async def test_cancel(num_to_cancel: int) -> None:
cache_item_task_finished = False

@alru_cache
async def coro(val: int) -> int:
# I am a long running coro function
nonlocal cache_item_task_finished
await asyncio.sleep(2)
cache_item_task_finished = True
return val

# create 3 tasks for the cached function using the same key
tasks = [asyncio.create_task(coro(1)) for _ in range(3)]

# force the event loop to run once so the tasks can begin
await asyncio.sleep(0)

# maybe cancel some tasks
for i in range(num_to_cancel):
tasks[i].cancel()

# allow enough time for the non-cancelled tasks to complete
await asyncio.sleep(3)

# check state
assert cache_item_task_finished == (num_to_cancel < 3)


@pytest.mark.asyncio
async def test_cancel_single_waiter_triggers_handle_cancelled_error():
# This test ensures the _handle_cancelled_error path (waiters == 1) is exercised.
cache_item_task_finished = False

@alru_cache
async def coro(val: int) -> int:
nonlocal cache_item_task_finished
await asyncio.sleep(2)
cache_item_task_finished = True
return val

task = asyncio.create_task(coro(42))
await asyncio.sleep(0)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

# The underlying coroutine should be cancelled, so the flag should remain False
assert cache_item_task_finished is False
Loading