Skip to content

Commit 3edce14

Browse files
authored
Merge pull request #219 from mindflayer/external-pr
External contribution
2 parents 2bca049 + a5af5c3 commit 3edce14

File tree

5 files changed

+137
-26
lines changed

5 files changed

+137
-26
lines changed

mocket/mocket.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@
4949
except ImportError:
5050
pyopenssl_override = False
5151

52+
try: # pragma: no cover
53+
from aiohttp import TCPConnector
54+
55+
aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
56+
except (ImportError, AttributeError):
57+
aiohttp_make_ssl_context_cache_clear = None
58+
5259

5360
true_socket = socket.socket
5461
true_create_connection = socket.create_connection
@@ -85,6 +92,7 @@ class FakeSSLContext(SuperFakeSSLContext):
8592
"load_verify_locations",
8693
"set_alpn_protocols",
8794
"set_ciphers",
95+
"set_default_verify_paths",
8896
)
8997
sock = None
9098
post_handshake_auth = None
@@ -180,6 +188,8 @@ def __init__(
180188
self.type = int(type)
181189
self.proto = int(proto)
182190
self._truesocket_recording_dir = None
191+
self._did_handshake = False
192+
self._sent_non_empty_bytes = False
183193
self.kwargs = kwargs
184194

185195
def __str__(self):
@@ -218,7 +228,7 @@ def getsockopt(level, optname, buflen=None):
218228
return socket.SOCK_STREAM
219229

220230
def do_handshake(self):
221-
pass
231+
self._did_handshake = True
222232

223233
def getpeername(self):
224234
return self._address
@@ -257,6 +267,8 @@ def write(self, data):
257267

258268
@staticmethod
259269
def fileno():
270+
if Mocket.r_fd is not None:
271+
return Mocket.r_fd
260272
Mocket.r_fd, Mocket.w_fd = os.pipe()
261273
return Mocket.r_fd
262274

@@ -292,10 +304,21 @@ def sendall(self, data, entry=None, *args, **kwargs):
292304
self.fd.seek(0)
293305

294306
def read(self, buffersize):
295-
return self.fd.read(buffersize)
307+
rv = self.fd.read(buffersize)
308+
if rv:
309+
self._sent_non_empty_bytes = True
310+
if self._did_handshake and not self._sent_non_empty_bytes:
311+
raise ssl.SSLWantReadError("The operation did not complete (read)")
312+
return rv
296313

297314
def recv_into(self, buffer, buffersize=None, flags=None):
298-
return buffer.write(self.read(buffersize))
315+
if hasattr(buffer, "write"):
316+
return buffer.write(self.read(buffersize))
317+
# buffer is a memoryview
318+
data = self.read(buffersize)
319+
if data:
320+
buffer[: len(data)] = data
321+
return len(data)
299322

300323
def recv(self, buffersize, flags=None):
301324
if Mocket.r_fd and Mocket.w_fd:
@@ -455,8 +478,12 @@ def collect(cls, data):
455478

456479
@classmethod
457480
def reset(cls):
458-
cls.r_fd = None
459-
cls.w_fd = None
481+
if cls.r_fd is not None:
482+
os.close(cls.r_fd)
483+
cls.r_fd = None
484+
if cls.w_fd is not None:
485+
os.close(cls.w_fd)
486+
cls.w_fd = None
460487
cls._entries = collections.defaultdict(list)
461488
cls._requests = []
462489

@@ -527,6 +554,8 @@ def enable(namespace=None, truesocket_recording_dir=None):
527554
if pyopenssl_override: # pragma: no cover
528555
# Take out the pyopenssl version - use the default implementation
529556
extract_from_urllib3()
557+
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
558+
aiohttp_make_ssl_context_cache_clear()
530559

531560
@staticmethod
532561
def disable():
@@ -563,6 +592,8 @@ def disable():
563592
if pyopenssl_override: # pragma: no cover
564593
# Put the pyopenssl version back in place
565594
inject_into_urllib3()
595+
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
596+
aiohttp_make_ssl_context_cache_clear()
566597

567598
@classmethod
568599
def get_namespace(cls):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dynamic = ["version"]
3636
[project.optional-dependencies]
3737
test = [
3838
"pre-commit",
39+
"psutil",
3940
"pytest",
4041
"pytest-cov",
4142
"pytest-asyncio",

tests/main/test_mocket.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from unittest import TestCase
77
from unittest.mock import patch
88

9+
import httpx
10+
import psutil
911
import pytest
1012

1113
from mocket import Mocket, MocketEntry, Mocketizer, mocketize
@@ -190,3 +192,19 @@ def test_patch(
190192
):
191193
method_patch.return_value = "foo"
192194
assert os.getcwd() == "foo"
195+
196+
197+
@pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test")
198+
@pytest.mark.asyncio
199+
async def test_no_dangling_fds():
200+
url = "http://httpbin.local/ip"
201+
202+
proc = psutil.Process(os.getpid())
203+
204+
prev_num_fds = proc.num_fds()
205+
206+
async with Mocketizer(strict_mode=False):
207+
async with httpx.AsyncClient() as client:
208+
await client.get(url)
209+
210+
assert proc.num_fds() == prev_num_fds

tests/tests38/test_http_aiohttp.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import json
22
from unittest import IsolatedAsyncioTestCase
33

4-
import httpx
54
import pytest
65

76
from mocket.async_mocket import async_mocketize
8-
from mocket.mocket import Mocket
7+
from mocket.mocket import Mocket, Mocketizer
98
from mocket.mockhttp import Entry
109
from mocket.plugins.httpretty import HTTPretty, async_httprettified
1110

@@ -46,6 +45,23 @@ async def test_http_session(self):
4645

4746
self.assertEqual(len(Mocket.request_list()), 2)
4847

48+
@async_httprettified
49+
async def test_httprettish_session(self):
50+
HTTPretty.register_uri(
51+
HTTPretty.GET,
52+
self.target_url,
53+
body=json.dumps(dict(origin="127.0.0.1")),
54+
)
55+
56+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
57+
async with session.get(self.target_url) as get_response:
58+
assert get_response.status == 200
59+
assert await get_response.text() == '{"origin": "127.0.0.1"}'
60+
61+
class AioHttpsEntryTestCase(IsolatedAsyncioTestCase):
62+
timeout = aiohttp.ClientTimeout(total=3)
63+
target_url = "https://httpbin.localhost/anything/"
64+
4965
@async_mocketize
5066
async def test_https_session(self):
5167
body = "asd" * 100
@@ -67,7 +83,14 @@ async def test_https_session(self):
6783

6884
self.assertEqual(len(Mocket.request_list()), 2)
6985

70-
@pytest.mark.xfail
86+
@async_mocketize
87+
async def test_no_verify(self):
88+
Entry.single_register(Entry.GET, self.target_url, status=404)
89+
90+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
91+
async with session.get(self.target_url, ssl=False) as get_response:
92+
assert get_response.status == 404
93+
7194
@async_httprettified
7295
async def test_httprettish_session(self):
7396
HTTPretty.register_uri(
@@ -81,21 +104,15 @@ async def test_httprettish_session(self):
81104
assert get_response.status == 200
82105
assert await get_response.text() == '{"origin": "127.0.0.1"}'
83106

84-
85-
class HttpxEntryTestCase(IsolatedAsyncioTestCase):
86-
target_url = "http://httpbin.local/ip"
87-
88-
@async_httprettified
89-
async def test_httprettish_httpx_session(self):
90-
expected_response = {"origin": "127.0.0.1"}
91-
92-
HTTPretty.register_uri(
93-
HTTPretty.GET,
94-
self.target_url,
95-
body=json.dumps(expected_response),
96-
)
97-
98-
async with httpx.AsyncClient() as client:
99-
response = await client.get(self.target_url)
100-
assert response.status_code == 200
101-
assert response.json() == expected_response
107+
@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)')
108+
async def test_mocked_https_request_after_unmocked_https_request(self):
109+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
110+
response = await session.get(self.target_url + "real", ssl=False)
111+
assert response.status == 200
112+
113+
async with Mocketizer(None):
114+
Entry.single_register(Entry.GET, self.target_url + "mocked", status=404)
115+
async with aiohttp.ClientSession(timeout=self.timeout) as session:
116+
response = await session.get(self.target_url + "mocked", ssl=False)
117+
assert response.status == 404
118+
self.assertEqual(len(Mocket.request_list()), 1)

tests/tests38/test_http_httpx.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import json
2+
from unittest import IsolatedAsyncioTestCase
3+
4+
import httpx
5+
6+
from mocket.plugins.httpretty import HTTPretty, async_httprettified
7+
8+
9+
class HttpxEntryTestCase(IsolatedAsyncioTestCase):
10+
target_url = "http://httpbin.local/ip"
11+
12+
@async_httprettified
13+
async def test_httprettish_httpx_session(self):
14+
expected_response = {"origin": "127.0.0.1"}
15+
16+
HTTPretty.register_uri(
17+
HTTPretty.GET,
18+
self.target_url,
19+
body=json.dumps(expected_response),
20+
)
21+
22+
async with httpx.AsyncClient() as client:
23+
response = await client.get(self.target_url)
24+
assert response.status_code == 200
25+
assert response.json() == expected_response
26+
27+
28+
class HttpxHttpsEntryTestCase(IsolatedAsyncioTestCase):
29+
target_url = "https://httpbin.local/ip"
30+
31+
@async_httprettified
32+
async def test_httprettish_httpx_session(self):
33+
expected_response = {"origin": "127.0.0.1"}
34+
35+
HTTPretty.register_uri(
36+
HTTPretty.GET,
37+
self.target_url,
38+
body=json.dumps(expected_response),
39+
)
40+
41+
async with httpx.AsyncClient() as client:
42+
response = await client.get(self.target_url)
43+
assert response.status_code == 200
44+
assert response.json() == expected_response

0 commit comments

Comments
 (0)