|
49 | 49 | except ImportError: |
50 | 50 | pyopenssl_override = False |
51 | 51 |
|
| 52 | +try: # pragma: no cover |
| 53 | + from aiohttp import TCPConnector |
| 54 | + |
| 55 | + aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear |
| 56 | +except (ImportError, AttributeError): |
| 57 | + aiohttp_make_ssl_context_cache_clear = None |
| 58 | + |
52 | 59 |
|
53 | 60 | true_socket = socket.socket |
54 | 61 | true_create_connection = socket.create_connection |
@@ -85,6 +92,7 @@ class FakeSSLContext(SuperFakeSSLContext): |
85 | 92 | "load_verify_locations", |
86 | 93 | "set_alpn_protocols", |
87 | 94 | "set_ciphers", |
| 95 | + "set_default_verify_paths", |
88 | 96 | ) |
89 | 97 | sock = None |
90 | 98 | post_handshake_auth = None |
@@ -180,6 +188,8 @@ def __init__( |
180 | 188 | self.type = int(type) |
181 | 189 | self.proto = int(proto) |
182 | 190 | self._truesocket_recording_dir = None |
| 191 | + self._did_handshake = False |
| 192 | + self._sent_non_empty_bytes = False |
183 | 193 | self.kwargs = kwargs |
184 | 194 |
|
185 | 195 | def __str__(self): |
@@ -218,7 +228,7 @@ def getsockopt(level, optname, buflen=None): |
218 | 228 | return socket.SOCK_STREAM |
219 | 229 |
|
220 | 230 | def do_handshake(self): |
221 | | - pass |
| 231 | + self._did_handshake = True |
222 | 232 |
|
223 | 233 | def getpeername(self): |
224 | 234 | return self._address |
@@ -257,6 +267,8 @@ def write(self, data): |
257 | 267 |
|
258 | 268 | @staticmethod |
259 | 269 | def fileno(): |
| 270 | + if Mocket.r_fd is not None: |
| 271 | + return Mocket.r_fd |
260 | 272 | Mocket.r_fd, Mocket.w_fd = os.pipe() |
261 | 273 | return Mocket.r_fd |
262 | 274 |
|
@@ -292,10 +304,21 @@ def sendall(self, data, entry=None, *args, **kwargs): |
292 | 304 | self.fd.seek(0) |
293 | 305 |
|
294 | 306 | def read(self, buffersize): |
295 | | - return self.fd.read(buffersize) |
| 307 | + rv = self.fd.read(buffersize) |
| 308 | + if rv: |
| 309 | + self._sent_non_empty_bytes = True |
| 310 | + if self._did_handshake and not self._sent_non_empty_bytes: |
| 311 | + raise ssl.SSLWantReadError("The operation did not complete (read)") |
| 312 | + return rv |
296 | 313 |
|
297 | 314 | def recv_into(self, buffer, buffersize=None, flags=None): |
298 | | - return buffer.write(self.read(buffersize)) |
| 315 | + if hasattr(buffer, "write"): |
| 316 | + return buffer.write(self.read(buffersize)) |
| 317 | + # buffer is a memoryview |
| 318 | + data = self.read(buffersize) |
| 319 | + if data: |
| 320 | + buffer[: len(data)] = data |
| 321 | + return len(data) |
299 | 322 |
|
300 | 323 | def recv(self, buffersize, flags=None): |
301 | 324 | if Mocket.r_fd and Mocket.w_fd: |
@@ -455,8 +478,12 @@ def collect(cls, data): |
455 | 478 |
|
456 | 479 | @classmethod |
457 | 480 | def reset(cls): |
458 | | - cls.r_fd = None |
459 | | - cls.w_fd = None |
| 481 | + if cls.r_fd is not None: |
| 482 | + os.close(cls.r_fd) |
| 483 | + cls.r_fd = None |
| 484 | + if cls.w_fd is not None: |
| 485 | + os.close(cls.w_fd) |
| 486 | + cls.w_fd = None |
460 | 487 | cls._entries = collections.defaultdict(list) |
461 | 488 | cls._requests = [] |
462 | 489 |
|
@@ -527,6 +554,8 @@ def enable(namespace=None, truesocket_recording_dir=None): |
527 | 554 | if pyopenssl_override: # pragma: no cover |
528 | 555 | # Take out the pyopenssl version - use the default implementation |
529 | 556 | extract_from_urllib3() |
| 557 | + if aiohttp_make_ssl_context_cache_clear: # pragma: no cover |
| 558 | + aiohttp_make_ssl_context_cache_clear() |
530 | 559 |
|
531 | 560 | @staticmethod |
532 | 561 | def disable(): |
@@ -563,6 +592,8 @@ def disable(): |
563 | 592 | if pyopenssl_override: # pragma: no cover |
564 | 593 | # Put the pyopenssl version back in place |
565 | 594 | inject_into_urllib3() |
| 595 | + if aiohttp_make_ssl_context_cache_clear: # pragma: no cover |
| 596 | + aiohttp_make_ssl_context_cache_clear() |
566 | 597 |
|
567 | 598 | @classmethod |
568 | 599 | def get_namespace(cls): |
|
0 commit comments