Skip to content

Commit 0c5c07a

Browse files
authored
Fix regression (#239)
1 parent 501088e commit 0c5c07a

File tree

4 files changed

+66
-25
lines changed

4 files changed

+66
-25
lines changed

mocket/mocket.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import socket
1111
import ssl
1212
from datetime import datetime, timedelta
13-
from io import BytesIO
1413
from json.decoder import JSONDecodeError
14+
from typing import Optional, Tuple
1515

1616
import urllib3
1717
from urllib3.connection import match_hostname as urllib3_match_hostname
@@ -27,6 +27,7 @@
2727
from .utils import (
2828
SSL_PROTOCOL,
2929
MocketMode,
30+
MocketSocketCore,
3031
get_mocketize,
3132
hexdump,
3233
hexload,
@@ -73,15 +74,15 @@
7374

7475

7576
class SuperFakeSSLContext:
76-
"""For Python 3.6"""
77+
"""For Python 3.6 and newer."""
7778

7879
class FakeSetter(int):
7980
def __set__(self, *args):
8081
pass
8182

8283
minimum_version = FakeSetter()
8384
options = FakeSetter()
84-
verify_mode = FakeSetter(ssl.CERT_NONE)
85+
verify_mode = FakeSetter()
8586

8687

8788
class FakeSSLContext(SuperFakeSSLContext):
@@ -177,6 +178,7 @@ class MocketSocket:
177178
_secure_socket = False
178179
_did_handshake = False
179180
_sent_non_empty_bytes = False
181+
_io = None
180182

181183
def __init__(
182184
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
@@ -200,10 +202,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
200202
self.close()
201203

202204
@property
203-
def fd(self):
204-
if self._fd is None:
205-
self._fd = BytesIO()
206-
return self._fd
205+
def io(self):
206+
if self._io is None:
207+
self._io = MocketSocketCore((self._host, self._port))
208+
return self._io
209+
210+
def fileno(self):
211+
address = (self._host, self._port)
212+
r_fd, _ = Mocket.get_pair(address)
213+
if not r_fd:
214+
r_fd, w_fd = os.pipe()
215+
Mocket.set_pair(address, (r_fd, w_fd))
216+
return r_fd
207217

208218
def gettimeout(self):
209219
return self.timeout
@@ -264,19 +274,14 @@ def unwrap(self):
264274
def write(self, data):
265275
return self.send(encode_to_bytes(data))
266276

267-
def fileno(self):
268-
if self.true_socket:
269-
return self.true_socket.fileno()
270-
return self.fd.fileno()
271-
272277
def connect(self, address):
273278
self._address = self._host, self._port = address
274279
Mocket._address = address
275280

276281
def makefile(self, mode="r", bufsize=-1):
277282
self._mode = mode
278283
self._bufsize = bufsize
279-
return self.fd
284+
return self.io
280285

281286
def get_entry(self, data):
282287
return Mocket.get_entry(self._host, self._port, data)
@@ -292,13 +297,13 @@ def sendall(self, data, entry=None, *args, **kwargs):
292297
response = self.true_sendall(data, *args, **kwargs)
293298

294299
if response is not None:
295-
self.fd.seek(0)
296-
self.fd.write(response)
297-
self.fd.truncate()
298-
self.fd.seek(0)
300+
self.io.seek(0)
301+
self.io.write(response)
302+
self.io.truncate()
303+
self.io.seek(0)
299304

300305
def read(self, buffersize):
301-
rv = self.fd.read(buffersize)
306+
rv = self.io.read(buffersize)
302307
if rv:
303308
self._sent_non_empty_bytes = True
304309
if self._did_handshake and not self._sent_non_empty_bytes:
@@ -315,6 +320,9 @@ def recv_into(self, buffer, buffersize=None, flags=None):
315320
return len(data)
316321

317322
def recv(self, buffersize, flags=None):
323+
r_fd, _ = Mocket.get_pair((self._host, self._port))
324+
if r_fd:
325+
return os.read(r_fd, buffersize)
318326
data = self.read(buffersize)
319327
if data:
320328
return data
@@ -416,8 +424,8 @@ def true_sendall(self, data, *args, **kwargs):
416424

417425
def send(self, data, *args, **kwargs): # pragma: no cover
418426
entry = self.get_entry(data)
419-
kwargs["entry"] = entry
420427
if not entry or (entry and self._entry != entry):
428+
kwargs["entry"] = entry
421429
self.sendall(data, *args, **kwargs)
422430
else:
423431
req = Mocket.last_request()
@@ -441,12 +449,29 @@ def do_nothing(*args, **kwargs):
441449

442450

443451
class Mocket:
452+
_socket_pairs = {}
444453
_address = (None, None)
445454
_entries = collections.defaultdict(list)
446455
_requests = []
447456
_namespace = text_type(id(_entries))
448457
_truesocket_recording_dir = None
449458

459+
@classmethod
460+
def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
461+
"""
462+
Given the id() of the caller, return a pair of file descriptors
463+
as a tuple of two integers: (<read_fd>, <write_fd>)
464+
"""
465+
return cls._socket_pairs.get(address, (None, None))
466+
467+
@classmethod
468+
def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
469+
"""
470+
Store a pair of file descriptors under the key `id_`
471+
as a tuple of two integers: (<read_fd>, <write_fd>)
472+
"""
473+
cls._socket_pairs[address] = pair
474+
450475
@classmethod
451476
def register(cls, *entries):
452477
for entry in entries:
@@ -467,6 +492,10 @@ def collect(cls, data):
467492

468493
@classmethod
469494
def reset(cls):
495+
for r_fd, w_fd in cls._socket_pairs.values():
496+
os.close(r_fd)
497+
os.close(w_fd)
498+
cls._socket_pairs = {}
470499
cls._entries = collections.defaultdict(list)
471500
cls._requests = []
472501

mocket/mockhttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def can_handle(self, data):
201201
"""
202202
try:
203203
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
204-
method, path, version = self._parse_requestline(requestline)
204+
method, path, _ = self._parse_requestline(requestline)
205205
except ValueError:
206206
return self is getattr(Mocket, "_last_entry", None)
207207

mocket/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import binascii
4+
import io
5+
import os
46
import ssl
57
from typing import TYPE_CHECKING, Any, Callable, ClassVar
68

@@ -14,6 +16,21 @@
1416
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
1517

1618

19+
class MocketSocketCore(io.BytesIO):
20+
def __init__(self, address) -> None:
21+
self._address = address
22+
super().__init__()
23+
24+
def write(self, content):
25+
from mocket import Mocket
26+
27+
super().write(content)
28+
29+
_, w_fd = Mocket.get_pair(self._address)
30+
if w_fd:
31+
os.write(w_fd, content)
32+
33+
1734
def hexdump(binary_string: bytes) -> str:
1835
r"""
1936
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))

tests/main/test_asyncio.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import glob
33
import json
44
import socket
5-
import sys
65
import tempfile
76

87
import aiohttp
@@ -45,10 +44,6 @@ async def test_asyncio_connection():
4544

4645

4746
@pytest.mark.asyncio
48-
@pytest.mark.skipif(
49-
sys.version_info < (3, 11),
50-
reason="Looks like https://github.com/aio-libs/aiohttp/issues/5582",
51-
)
5247
@async_mocketize
5348
async def test_aiohttp():
5449
url = "https://bar.foo/"

0 commit comments

Comments
 (0)