Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2954,14 +2954,18 @@ def release(self, connection):
pass
self._locked = False

def disconnect(self):
"Disconnects all connections in the pool."
def disconnect(self, inuse_connections: bool = True):
"Disconnects either all connections in the pool or just the free connections."
self._checkpid()
try:
if self._in_maintenance:
self._lock.acquire()
self._locked = True
for connection in self._connections:
if inuse_connections:
connections = self._connections
else:
connections = self._get_free_connections()
for connection in connections:
connection.disconnect()
finally:
if self._locked:
Expand Down
33 changes: 28 additions & 5 deletions tests/test_asyncio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,20 @@ class DummyConnection(Connection):

def __init__(self, **kwargs):
self.kwargs = kwargs
self._connected = False

def repr_pieces(self):
return [("id", id(self)), ("kwargs", self.kwargs)]

async def connect(self):
pass
self._connected = True

async def disconnect(self):
pass
self._connected = False

@property
def is_connected(self):
return self._connected

async def can_read_destructive(self, timeout: float = 0):
return False
Expand Down Expand Up @@ -203,6 +208,20 @@ async def test_repr_contains_db_info_unix(self):
expected = "path=/abc,db=1,client_name=test-client"
assert expected in repr(pool)

async def test_pool_disconnect(self, master_host):
connection_kwargs = {
"host": master_host[0],
"port": master_host[1],
}
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
conn = await pool.get_connection()
await pool.disconnect(inuse_connections=True)
assert not conn.is_connected

await conn.connect()
await pool.disconnect(inuse_connections=False)
assert conn.is_connected


class TestBlockingConnectionPool:
@asynccontextmanager
Expand Down Expand Up @@ -231,17 +250,21 @@ async def test_connection_creation(self, master_host):
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs

async def test_disconnect(self, master_host):
"""A regression test for #1047"""
async def test_pool_disconnect(self, master_host):
connection_kwargs = {
"foo": "bar",
"biz": "baz",
"host": master_host[0],
"port": master_host[1],
}
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
await pool.get_connection()
conn = await pool.get_connection()
await pool.disconnect()
assert not conn.is_connected

await conn.connect()
await pool.disconnect(inuse_connections=False)
assert conn.is_connected

async def test_multiple_connections(self, master_host):
connection_kwargs = {"host": master_host[0], "port": master_host[1]}
Expand Down
38 changes: 37 additions & 1 deletion tests/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ class DummyConnection:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()
self._sock = None

def connect(self):
pass
self._sock = mock.Mock()

def disconnect(self):
self._sock = None

def can_read(self):
return False
Expand Down Expand Up @@ -140,6 +144,21 @@ def test_repr_contains_db_info_unix(self):
expected = "path=/abc,db=1,client_name=test-client"
assert expected in repr(pool)

def test_pool_disconnect(self, master_host):
connection_kwargs = {
"host": master_host[0],
"port": master_host[1],
}
pool = self.get_pool(connection_kwargs=connection_kwargs)

conn = pool.get_connection()
pool.disconnect()
assert not conn._sock

conn.connect()
pool.disconnect(inuse_connections=False)
assert conn._sock


class TestBlockingConnectionPool:
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
Expand Down Expand Up @@ -244,6 +263,23 @@ def test_initialise_pool_with_cache(self, master_host):
)
assert isinstance(pool.get_connection(), CacheProxyConnection)

def test_pool_disconnect(self, master_host):
connection_kwargs = {
"foo": "bar",
"biz": "baz",
"host": master_host[0],
"port": master_host[1],
}
pool = self.get_pool(connection_kwargs=connection_kwargs)

conn = pool.get_connection()
pool.disconnect()
assert not conn._sock

conn.connect()
pool.disconnect(inuse_connections=False)
assert conn._sock


class TestConnectionPoolURLParsing:
def test_hostname(self):
Expand Down