diff --git a/fsspec/implementations/asyn_wrapper.py b/fsspec/implementations/asyn_wrapper.py index 9d106a74c..36db9c1b4 100644 --- a/fsspec/implementations/asyn_wrapper.py +++ b/fsspec/implementations/asyn_wrapper.py @@ -6,7 +6,7 @@ from fsspec.asyn import AsyncFileSystem, running_async -def async_wrapper(func, obj=None): +def async_wrapper(func, obj=None, semaphore=None): """ Wraps a synchronous function to make it awaitable. @@ -16,6 +16,8 @@ def async_wrapper(func, obj=None): The synchronous function to wrap. obj : object, optional The instance to bind the function to, if applicable. + semaphore : asyncio.Semaphore, optional + A semaphore to limit concurrent calls. Returns ------- @@ -25,6 +27,9 @@ def async_wrapper(func, obj=None): @functools.wraps(func) async def wrapper(*args, **kwargs): + if semaphore: + async with semaphore: + return await asyncio.to_thread(func, *args, **kwargs) return await asyncio.to_thread(func, *args, **kwargs) return wrapper @@ -52,6 +57,8 @@ def __init__( asynchronous=None, target_protocol=None, target_options=None, + semaphore=None, + max_concurrent_tasks=None, **kwargs, ): if asynchronous is None: @@ -62,6 +69,7 @@ def __init__( else: self.sync_fs = fsspec.filesystem(target_protocol, **target_options) self.protocol = self.sync_fs.protocol + self.semaphore = semaphore self._wrap_all_sync_methods() @property @@ -83,7 +91,7 @@ def _wrap_all_sync_methods(self): method = getattr(self.sync_fs, method_name) if callable(method) and not inspect.iscoroutinefunction(method): - async_method = async_wrapper(method, obj=self) + async_method = async_wrapper(method, obj=self, semaphore=self.semaphore) setattr(self, f"_{method_name}", async_method) @classmethod diff --git a/fsspec/implementations/tests/test_asyn_wrapper.py b/fsspec/implementations/tests/test_asyn_wrapper.py index 13b766fb1..ee2c8a79a 100644 --- a/fsspec/implementations/tests/test_asyn_wrapper.py +++ b/fsspec/implementations/tests/test_asyn_wrapper.py @@ -1,15 +1,49 @@ import asyncio import os +from itertools import cycle import pytest import fsspec +from fsspec.asyn import AsyncFileSystem from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper from fsspec.implementations.local import LocalFileSystem from .test_local import csv_files, filetexts +class LockedFileSystem(AsyncFileSystem): + """ + A mock file system that simulates a synchronous locking file systems with delays. + """ + + def __init__( + self, + asynchronous: bool = False, + delays=None, + ) -> None: + self.lock = asyncio.Lock() + self.delays = cycle((0.03, 0.01) if delays is None else delays) + + super().__init__(asynchronous=asynchronous) + + async def _cat_file(self, path, start=None, end=None) -> bytes: + await self._simulate_io_operation(path) + return path.encode() + + async def _await_io(self) -> None: + await asyncio.sleep(next(self.delays)) + + async def _simulate_io_operation(self, path) -> None: + await self._check_active() + async with self.lock: + await self._await_io() + + async def _check_active(self) -> None: + if self.lock.locked(): + raise RuntimeError("Concurrent requests!") + + @pytest.mark.asyncio async def test_is_async_default(): fs = fsspec.filesystem("file") @@ -161,3 +195,26 @@ def test_open(tmpdir): ) with of as f: assert f.read() == b"hello" + + +@pytest.mark.asyncio +async def test_semaphore_synchronous(): + fs = AsyncFileSystemWrapper( + LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(1) + ) + + paths = [f"path_{i}" for i in range(1, 3)] + results = await asyncio.gather(*(fs._cat_file(path) for path in paths)) + + assert set(results) == {path.encode() for path in paths} + + +@pytest.mark.asyncio +async def test_deadlock_when_asynchronous(): + fs = AsyncFileSystemWrapper( + LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(3) + ) + paths = [f"path_{i}" for i in range(1, 3)] + + with pytest.raises(RuntimeError, match="Concurrent requests!"): + await asyncio.gather(*(fs._cat_file(path) for path in paths))