Skip to content

Commit d1d3eb0

Browse files
committed
Add semaphore to AsyncFileSystemWrapper
- Initialize a semaphore in AsyncFileSystemWrapper when asynchronous mode is disabled. This ensures that concurrent calls are managed safely, preventing potential deadlocks in file systems that cannot handle concurrent requests.
- Add a LockedFileSystem class to test this case.
1 parent c46db87 commit d1d3eb0

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

fsspec/implementations/asyn_wrapper.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fsspec.asyn import AsyncFileSystem, running_async
77

88

9-
def async_wrapper(func, obj=None):
9+
def async_wrapper(func, obj=None, semaphore=None):
1010
"""
1111
Wraps a synchronous function to make it awaitable.
1212
@@ -16,6 +16,8 @@ def async_wrapper(func, obj=None):
1616
The synchronous function to wrap.
1717
obj : object, optional
1818
The instance to bind the function to, if applicable.
19+
semaphore : asyncio.Semaphore, optional
20+
A semaphore to limit concurrent calls.
1921
2022
Returns
2123
-------
@@ -25,6 +27,9 @@ def async_wrapper(func, obj=None):
2527

2628
@functools.wraps(func)
async def wrapper(*args, **kwargs):
    # Run the wrapped synchronous callable in a worker thread. When a
    # semaphore was supplied, hold it for the duration of the call;
    # otherwise dispatch to the thread directly.
    if not semaphore:
        return await asyncio.to_thread(func, *args, **kwargs)

    async with semaphore:
        return await asyncio.to_thread(func, *args, **kwargs)
2934

3035
return wrapper
@@ -62,6 +67,7 @@ def __init__(
6267
else:
6368
self.sync_fs = fsspec.filesystem(target_protocol, **target_options)
6469
self.protocol = self.sync_fs.protocol
70+
self.semaphore = asyncio.Semaphore(1) if not asynchronous else None
6571
self._wrap_all_sync_methods()
6672

6773
@property
@@ -83,7 +89,7 @@ def _wrap_all_sync_methods(self):
8389

8490
method = getattr(self.sync_fs, method_name)
8591
if callable(method) and not inspect.iscoroutinefunction(method):
86-
async_method = async_wrapper(method, obj=self)
92+
async_method = async_wrapper(method, obj=self, semaphore=self.semaphore)
8793
setattr(self, f"_{method_name}", async_method)
8894

8995
@classmethod

fsspec/implementations/tests/test_asyn_wrapper.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,49 @@
11
import asyncio
22
import os
3+
from itertools import cycle
34

45
import pytest
56

67
import fsspec
8+
from fsspec.asyn import AsyncFileSystem
79
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
810
from fsspec.implementations.local import LocalFileSystem
911

1012
from .test_local import csv_files, filetexts
1113

1214

15+
class LockedFileSystem(AsyncFileSystem):
    """
    Mock async file system that rejects overlapping requests.

    Each simulated read holds an ``asyncio.Lock`` while sleeping for the
    next value of a repeating sequence of delays. A request that arrives
    while the lock is held raises ``RuntimeError``.
    """

    def __init__(
        self,
        asynchronous: bool = False,
        delays: tuple[float, ...] | None = None,
    ) -> None:
        # Without explicit delays, alternate between 0.03 s and 0.01 s.
        if delays is None:
            delays = (0.03, 0.01)

        self.lock = asyncio.Lock()
        self.delays = cycle(delays)

        super().__init__(asynchronous=asynchronous)

    async def _check_active(self) -> None:
        """Raise ``RuntimeError`` if the lock is currently held."""
        busy = self.lock.locked()
        if busy:
            raise RuntimeError("Concurrent requests!")

    async def _await_io(self) -> None:
        """Sleep for the next delay taken from the repeating sequence."""
        pause = next(self.delays)
        await asyncio.sleep(pause)

    async def _simulate_io_operation(self, path) -> None:
        """Fail if a request is in flight, else hold the lock while sleeping."""
        # The check happens before acquiring, so an overlapping call raises
        # instead of waiting for the lock.
        await self._check_active()
        async with self.lock:
            await self._await_io()

    async def _cat_file(self, path, start=None, end=None) -> bytes:
        """Simulate a read and return ``path`` encoded as bytes."""
        await self._simulate_io_operation(path)
        return path.encode()
45+
46+
1347
@pytest.mark.asyncio
1448
async def test_is_async_default():
1549
fs = fsspec.filesystem("file")
@@ -161,3 +195,22 @@ def test_open(tmpdir):
161195
)
162196
with of as f:
163197
assert f.read() == b"hello"
198+
199+
200+
@pytest.mark.asyncio
async def test_semaphore_synchronous():
    """Concurrent reads through a wrapper built with asynchronous=False succeed."""
    fs = AsyncFileSystemWrapper(LockedFileSystem(), asynchronous=False)
    paths = [f"path_{i}" for i in range(1, 3)]

    # Launch all reads at once; LockedFileSystem raises if they overlap.
    pending = [fs._cat_file(path) for path in paths]
    results = await asyncio.gather(*pending)

    expected = {path.encode() for path in paths}
    assert set(results) == expected
208+
209+
210+
@pytest.mark.asyncio
async def test_deadlock_when_asynchronous():
    """Concurrent reads through a wrapper built with asynchronous=True raise."""
    fs = AsyncFileSystemWrapper(LockedFileSystem(), asynchronous=True)
    paths = [f"path_{i}" for i in range(1, 3)]

    # Overlapping requests are expected to trip LockedFileSystem's
    # concurrency check.
    with pytest.raises(RuntimeError, match="Concurrent requests!"):
        pending = [fs._cat_file(path) for path in paths]
        await asyncio.gather(*pending)

0 commit comments

Comments
 (0)