Skip to content

Commit 1fd1064

Browse files
committed
Fixing sync BlockingConnectionPool's disconnect method to follow the definition of the interface
1 parent 1c474e5 commit 1fd1064

File tree

3 files changed

+76
-8
lines changed

3 files changed

+76
-8
lines changed

redis/connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,14 +2954,18 @@ def release(self, connection):
29542954
pass
29552955
self._locked = False
29562956

2957-
def disconnect(self):
2958-
"Disconnects all connections in the pool."
2957+
def disconnect(self, inuse_connections: bool = True):
2958+
"Disconnects either all connections in the pool or just the free connections."
29592959
self._checkpid()
29602960
try:
29612961
if self._in_maintenance:
29622962
self._lock.acquire()
29632963
self._locked = True
2964-
for connection in self._connections:
2964+
if inuse_connections:
2965+
connections = self._connections
2966+
else:
2967+
connections = self._get_free_connections()
2968+
for connection in connections:
29652969
connection.disconnect()
29662970
finally:
29672971
if self._locked:

tests/test_asyncio/test_connection_pool.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,20 @@ class DummyConnection(Connection):
9595

9696
def __init__(self, **kwargs):
9797
self.kwargs = kwargs
98+
self._connected = False
9899

99100
def repr_pieces(self):
100101
return [("id", id(self)), ("kwargs", self.kwargs)]
101102

102103
async def connect(self):
103-
pass
104+
self._connected = True
104105

105106
async def disconnect(self):
106-
pass
107+
self._connected = False
108+
109+
@property
110+
def is_connected(self):
111+
return self._connected
107112

108113
async def can_read_destructive(self, timeout: float = 0):
109114
return False
@@ -203,6 +208,22 @@ async def test_repr_contains_db_info_unix(self):
203208
expected = "path=/abc,db=1,client_name=test-client"
204209
assert expected in repr(pool)
205210

211+
async def test_pool_disconnect(self, master_host):
212+
connection_kwargs = {
213+
"host": master_host[0],
214+
"port": master_host[1],
215+
}
216+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
217+
conn = await pool.get_connection()
218+
await pool.disconnect(inuse_connections=True)
219+
assert not conn.is_connected
220+
221+
await conn.connect()
222+
await pool.disconnect(inuse_connections=False)
223+
assert conn.is_connected
224+
225+
226+
206227

207228
class TestBlockingConnectionPool:
208229
@asynccontextmanager
@@ -231,7 +252,7 @@ async def test_connection_creation(self, master_host):
231252
assert isinstance(connection, DummyConnection)
232253
assert connection.kwargs == connection_kwargs
233254

234-
async def test_disconnect(self, master_host):
255+
async def test_pool_disconnect(self, master_host):
235256
"""A regression test for #1047"""
236257
connection_kwargs = {
237258
"foo": "bar",
@@ -240,8 +261,13 @@ async def test_disconnect(self, master_host):
240261
"port": master_host[1],
241262
}
242263
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
243-
await pool.get_connection()
264+
conn = await pool.get_connection()
244265
await pool.disconnect()
266+
assert not conn.is_connected
267+
268+
await conn.connect()
269+
await pool.disconnect(inuse_connections=False)
270+
assert conn.is_connected
245271

246272
async def test_multiple_connections(self, master_host):
247273
connection_kwargs = {"host": master_host[0], "port": master_host[1]}

tests/test_connection_pool.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from threading import Thread
66
from unittest import mock
77

8+
from mock import Mock
89
import pytest
910
import redis
1011
from redis.cache import CacheConfig
@@ -29,9 +30,13 @@ class DummyConnection:
2930
def __init__(self, **kwargs):
3031
self.kwargs = kwargs
3132
self.pid = os.getpid()
33+
self._sock = None
3234

3335
def connect(self):
34-
pass
36+
self._sock = Mock()
37+
38+
def disconnect(self):
39+
self._sock = None
3540

3641
def can_read(self):
3742
return False
@@ -140,6 +145,22 @@ def test_repr_contains_db_info_unix(self):
140145
expected = "path=/abc,db=1,client_name=test-client"
141146
assert expected in repr(pool)
142147

148+
def test_pool_disconnect(self, master_host):
149+
"""A regression test for #1047"""
150+
connection_kwargs = {
151+
"host": master_host[0],
152+
"port": master_host[1],
153+
}
154+
pool = self.get_pool(connection_kwargs=connection_kwargs)
155+
156+
conn = pool.get_connection()
157+
pool.disconnect()
158+
assert not conn._sock
159+
160+
conn.connect()
161+
pool.disconnect(inuse_connections=False)
162+
assert conn._sock
163+
143164

144165
class TestBlockingConnectionPool:
145166
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
@@ -244,6 +265,23 @@ def test_initialise_pool_with_cache(self, master_host):
244265
)
245266
assert isinstance(pool.get_connection(), CacheProxyConnection)
246267

268+
def test_pool_disconnect(self, master_host):
269+
"""A regression test for #1047"""
270+
connection_kwargs = {
271+
"foo": "bar",
272+
"biz": "baz",
273+
"host": master_host[0],
274+
"port": master_host[1],
275+
}
276+
pool = self.get_pool(connection_kwargs=connection_kwargs)
277+
278+
conn = pool.get_connection()
279+
pool.disconnect()
280+
assert not conn._sock
281+
282+
conn.connect()
283+
pool.disconnect(inuse_connections=False)
284+
assert conn._sock
247285

248286
class TestConnectionPoolURLParsing:
249287
def test_hostname(self):

0 commit comments

Comments
 (0)