From af2d3328d0786b690b3710ad3672442c66a4dddd Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 17 Oct 2025 14:02:45 +0300 Subject: [PATCH 1/2] Fixing sync BlockingConnectionPool's disconnect method to follow the definition of the interface --- redis/connection.py | 10 ++++-- tests/test_asyncio/test_connection_pool.py | 35 +++++++++++++++++--- tests/test_connection_pool.py | 37 +++++++++++++++++++++- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 35e2bdf9ce..389529a1a7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2954,14 +2954,18 @@ def release(self, connection): pass self._locked = False - def disconnect(self): - "Disconnects all connections in the pool." + def disconnect(self, inuse_connections: bool = True): + "Disconnects either all connections in the pool or just the free connections." self._checkpid() try: if self._in_maintenance: self._lock.acquire() self._locked = True - for connection in self._connections: + if inuse_connections: + connections = self._connections + else: + connections = self._get_free_connections() + for connection in connections: connection.disconnect() finally: if self._locked: diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index cb3dac9604..b4dcc4a7b0 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -95,15 +95,20 @@ class DummyConnection(Connection): def __init__(self, **kwargs): self.kwargs = kwargs + self._connected = False def repr_pieces(self): return [("id", id(self)), ("kwargs", self.kwargs)] async def connect(self): - pass + self._connected = True async def disconnect(self): - pass + self._connected = False + + @property + def is_connected(self): + return self._connected async def can_read_destructive(self, timeout: float = 0): return False @@ -203,6 +208,22 @@ async def test_repr_contains_db_info_unix(self): expected = "path=/abc,db=1,client_name=test-client" assert expected in repr(pool) + async def test_pool_disconnect(self, master_host): + connection_kwargs = { + "host": master_host[0], + "port": master_host[1], + } + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + conn = await pool.get_connection() + await pool.disconnect(inuse_connections=True) + assert not conn.is_connected + + await conn.connect() + await pool.disconnect(inuse_connections=False) + assert conn.is_connected + + + class TestBlockingConnectionPool: @asynccontextmanager @@ -231,8 +252,7 @@ async def test_connection_creation(self, master_host): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs - async def test_disconnect(self, master_host): - """A regression test for #1047""" + async def test_pool_disconnect(self, master_host): connection_kwargs = { "foo": "bar", "biz": "baz", @@ -240,8 +260,13 @@ async def test_disconnect(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - await pool.get_connection() + conn = await pool.get_connection() await pool.disconnect() + assert not conn.is_connected + + await conn.connect() + await pool.disconnect(inuse_connections=False) + assert conn.is_connected async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2397f15600..eea7ca2d7c 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -29,9 +29,13 @@ class DummyConnection: def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() + self._sock = None def connect(self): - pass + self._sock = mock.Mock() + + def disconnect(self): + self._sock = None def can_read(self): return False @@ -140,6 +144,21 @@ def test_repr_contains_db_info_unix(self): expected = "path=/abc,db=1,client_name=test-client" assert expected in repr(pool) + def test_pool_disconnect(self, master_host): + connection_kwargs = { + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + + conn = pool.get_connection() + pool.disconnect() + assert not conn._sock + + conn.connect() + pool.disconnect(inuse_connections=False) + assert conn._sock + class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): @@ -244,6 +263,22 @@ def test_initialise_pool_with_cache(self, master_host): ) assert isinstance(pool.get_connection(), CacheProxyConnection) + def test_pool_disconnect(self, master_host): + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + + conn = pool.get_connection() + pool.disconnect() + assert not conn._sock + + conn.connect() + pool.disconnect(inuse_connections=False) + assert conn._sock class TestConnectionPoolURLParsing: def test_hostname(self): From 162d07b213ec0ad57bf64ee8ea6a7ee343bc695c Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 21 Oct 2025 13:42:55 +0300 Subject: [PATCH 2/2] Fixing linter errors --- tests/test_asyncio/test_connection_pool.py | 2 -- tests/test_connection_pool.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index b4dcc4a7b0..e658c14188 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -223,8 +223,6 @@ async def test_pool_disconnect(self, master_host): assert conn.is_connected - - class TestBlockingConnectionPool: @asynccontextmanager async def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index eea7ca2d7c..7365c6ff13 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -280,6 +280,7 @@ def test_pool_disconnect(self, master_host): pool.disconnect(inuse_connections=False) assert conn._sock + class TestConnectionPoolURLParsing: def test_hostname(self): pool = redis.ConnectionPool.from_url("redis://my.host")