Skip to content

Commit

Permalink
Enable TLS protected connections to memcache (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
magnuswatn authored Jan 2, 2023
1 parent 112c43e commit e0615c8
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 7 deletions.
15 changes: 11 additions & 4 deletions aiomcache/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import functools
import re
import sys
from typing import (Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar,
Union, overload)
from typing import (Any, Awaitable, Callable, Dict, Generic, Mapping, Optional, Tuple,
TypeVar, Union, overload)

from . import constants as const
from .exceptions import ClientException, ValidationException
Expand Down Expand Up @@ -52,6 +52,7 @@ async def wrapper(self: _Client, *args: _P.args, # type: ignore[misc]
class FlagClient(Generic[_T]):
def __init__(self, host: str, port: int = 11211, *,
pool_size: int = 2, pool_minsize: Optional[int] = None,
conn_args: Optional[Mapping[str, Any]] = None,
get_flag_handler: Optional[_GetFlagHandler[_T]] = None,
set_flag_handler: Optional[_SetFlagHandler[_T]] = None):
"""
Expand All @@ -61,6 +62,9 @@ def __init__(self, host: str, port: int = 11211, *,
:param port: memcached port
:param pool_size: max connection pool size
:param pool_minsize: min connection pool size
:param conn_args: extra arguments passed to
asyncio.open_connection(). For details, see:
https://docs.python.org/3/library/asyncio-stream.html#asyncio.open_connection.
:param get_flag_handler: async method to call to convert flagged
values. Method takes tuple: (value, flags) and should return
processed value or raise ClientException if not supported.
Expand All @@ -72,7 +76,8 @@ def __init__(self, host: str, port: int = 11211, *,
pool_minsize = pool_size

self._pool = MemcachePool(
host, port, minsize=pool_minsize, maxsize=pool_size)
host, port, minsize=pool_minsize, maxsize=pool_size,
conn_args=conn_args)

self._get_flag_handler = get_flag_handler
self._set_flag_handler = set_flag_handler
Expand Down Expand Up @@ -493,6 +498,8 @@ async def flush_all(self, conn: Connection) -> None:

class Client(FlagClient[bytes]):
def __init__(self, host: str, port: int = 11211, *,
pool_size: int = 2, pool_minsize: Optional[int] = None):
pool_size: int = 2, pool_minsize: Optional[int] = None,
conn_args: Optional[Mapping[str, Any]] = None):
super().__init__(host, port, pool_size=pool_size, pool_minsize=pool_minsize,
conn_args=conn_args,
get_flag_handler=None, set_flag_handler=None)
8 changes: 5 additions & 3 deletions aiomcache/pool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import NamedTuple, Optional, Set
from typing import Any, Mapping, NamedTuple, Optional, Set

__all__ = ['MemcachePool']

Expand All @@ -10,11 +10,13 @@ class Connection(NamedTuple):


class MemcachePool:
def __init__(self, host: str, port: int, *, minsize: int, maxsize: int):
def __init__(self, host: str, port: int, *, minsize: int, maxsize: int,
conn_args: Optional[Mapping[str, Any]] = None):
self._host = host
self._port = port
self._minsize = minsize
self._maxsize = maxsize
self.conn_args = conn_args or {}
self._pool: asyncio.Queue[Connection] = asyncio.Queue()
self._in_use: Set[Connection] = set()

Expand Down Expand Up @@ -66,7 +68,7 @@ def release(self, conn: Connection) -> None:
async def _create_new_conn(self) -> Optional[Connection]:
if self.size() < self._maxsize:
reader, writer = await asyncio.open_connection(
self._host, self._port)
self._host, self._port, **self.conn_args)
if self.size() < self._maxsize:
return Connection(reader, writer)
else:
Expand Down
39 changes: 39 additions & 0 deletions tests/conn_args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import ssl
import sys
from asyncio import StreamReader, StreamWriter
from unittest import mock

import pytest

from aiomcache import Client
from .conftest import McacheParams


@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python3.8")
async def test_params_forwarded_from_client() -> None:
client = Client("host", port=11211, conn_args={
"ssl": True, "ssl_handshake_timeout": 20
})

with mock.patch(
"asyncio.open_connection",
return_value=(
mock.create_autospec(StreamReader),
mock.create_autospec(StreamWriter),
),
autospec=True,
) as oc:
await client._pool.acquire()

oc.assert_called_with("host", 11211, ssl=True, ssl_handshake_timeout=20)


async def test_ssl_client_fails_against_plaintext_server(
mcache_params: McacheParams,
) -> None:
client = Client(**mcache_params, conn_args={"ssl": True})
# If SSL was correctly enabled, this should
# fail, since SSL isn't enabled on the memcache
# server.
with pytest.raises(ssl.SSLError):
await client.get(b"key")
2 changes: 2 additions & 0 deletions tests/pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def test_maxsize_greater_than_minsize(mcache_params: McacheParams) -> None
assert isinstance(conn.reader, asyncio.StreamReader)
assert isinstance(conn.writer, asyncio.StreamWriter)
pool.release(conn)
await pool.clear()


async def test_0_minsize(mcache_params: McacheParams) -> None:
Expand All @@ -134,6 +135,7 @@ async def test_0_minsize(mcache_params: McacheParams) -> None:
assert isinstance(conn.reader, asyncio.StreamReader)
assert isinstance(conn.writer, asyncio.StreamWriter)
pool.release(conn)
await pool.clear()


async def test_bad_connection(mcache_params: McacheParams) -> None:
Expand Down

0 comments on commit e0615c8

Please sign in to comment.