1212
1313import decorator
1414import hexdump
15+ import urllib3
16+ from urllib3 .util .ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
17+ from urllib3 .util .ssl_ import wrap_socket as urllib3_wrap_socket
1518
1619from .compat import (
1720 FileNotFoundError ,
2225 encode_to_bytes ,
2326 text_type ,
2427)
25- from .utils import MocketSocketCore , wrap_ssl_socket , SSL_PROTOCOL
28+ from .utils import SSL_PROTOCOL , MocketSocketCore , wrap_ssl_socket
2629
2730xxh32 = None
2831try :
4144except ImportError :
4245 pyopenssl_override = False
4346
47+
4448true_socket = socket .socket
4549true_create_connection = socket .create_connection
4650true_gethostbyname = socket .gethostbyname
5054true_ssl_socket = ssl .SSLSocket
5155true_ssl_context = ssl .SSLContext
5256true_inet_pton = socket .inet_pton
57+ true_urllib3_wrap_socket = urllib3_wrap_socket
58+ true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
5359
5460
5561class SuperFakeSSLContext (object ):
@@ -87,6 +93,8 @@ def load_default_certs(*args, **kwargs):
8793
8894 @staticmethod
8995 def wrap_socket (sock = sock , * args , ** kwargs ):
96+ sock .kwargs = kwargs
97+ sock ._secure_socket = True
9098 return sock
9199
92100 def wrap_bio (self , incoming , outcoming , * args , ** kwargs ):
@@ -126,6 +134,7 @@ class MocketSocket(object):
126134 compression = lambda s : ssl .OP_NO_COMPRESSION
127135 _mode = None
128136 _bufsize = None
137+ _secure_socket = False
129138
130139 def __init__ (
131140 self , family = socket .AF_INET , type = socket .SOCK_STREAM , proto = 0 , * args , ** kwargs
@@ -139,6 +148,7 @@ def __init__(
139148 self .type = int (type )
140149 self .proto = int (proto )
141150 self ._truesocket_recording_dir = None
151+ self .kwargs = kwargs
142152
143153 sock = kwargs .get ("sock" )
144154 if sock is not None :
@@ -175,6 +185,9 @@ def settimeout(self, timeout):
175185 except AttributeError :
176186 pass
177187
188+ def getsockopt (self , level , optname , buflen = None ):
189+ return socket .SOCK_STREAM
190+
178191 def do_handshake (self ):
179192 pass
180193
@@ -309,6 +322,18 @@ def true_sendall(self, data, *args, **kwargs):
309322 host , port = Mocket ._address
310323 host = true_gethostbyname (host )
311324
325+ if isinstance (self .true_socket , true_socket ) and self ._secure_socket :
326+ try :
327+ self = MocketSocket (sock = self )
328+ except TypeError :
329+ ssl_context = self .kwargs .get ("ssl_context" )
330+ server_hostname = self .kwargs .get ("server_hostname" )
331+ self .true_socket = true_ssl_context .wrap_socket (
332+ self = ssl_context ,
333+ sock = self .true_socket ,
334+ server_hostname = server_hostname ,
335+ )
336+
312337 try :
313338 self .true_socket .connect ((host , port ))
314339 except (OSError , socket .error , ValueError ):
@@ -342,7 +367,7 @@ def true_sendall(self, data, *args, **kwargs):
342367 )
343368 )
344369
345- # response back to .sendall() which writes it to the mocket socket and flush the BytesIO
370+ # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
346371 return encoded_response
347372
348373 def send (self , data , * args , ** kwargs ): # pragma: no cover
@@ -438,11 +463,20 @@ def enable(namespace=None, truesocket_recording_dir=None):
438463 (2 , 1 , 6 , "" , (host , port ))
439464 ]
440465 ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = FakeSSLContext .wrap_socket
441- ssl .SSLSocket = ssl .__dict__ ["SSLSocket" ] = MocketSocket
466+ # ssl.SSLSocket = ssl.__dict__["SSLSocket"] = MocketSocket
442467 ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = FakeSSLContext
443468 socket .inet_pton = socket .__dict__ ["inet_pton" ] = lambda family , ip : byte_type (
444469 "\x7f \x00 \x00 \x01 " , "utf-8"
445470 )
471+ urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ [
472+ "wrap_socket"
473+ ] = FakeSSLContext .wrap_socket
474+ urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
475+ "ssl_wrap_socket"
476+ ] = FakeSSLContext .wrap_socket
477+ urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
478+ "ssl_wrap_socket"
479+ ] = FakeSSLContext .wrap_socket
446480 if pyopenssl_override :
447481 # Take out the pyopenssl version - use the default implementation
448482 extract_from_urllib3 ()
@@ -459,9 +493,18 @@ def disable():
459493 socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = true_gethostbyname
460494 socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = true_getaddrinfo
461495 ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = true_ssl_wrap_socket
462- ssl .SSLSocket = ssl .__dict__ ["SSLSocket" ] = true_ssl_socket
496+ # ssl.SSLSocket = ssl.__dict__["SSLSocket"] = true_ssl_socket
463497 ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = true_ssl_context
464498 socket .inet_pton = socket .__dict__ ["inet_pton" ] = true_inet_pton
499+ urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ [
500+ "wrap_socket"
501+ ] = true_urllib3_wrap_socket
502+ urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
503+ "ssl_wrap_socket"
504+ ] = true_urllib3_ssl_wrap_socket
505+ urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
506+ "ssl_wrap_socket"
507+ ] = true_urllib3_ssl_wrap_socket
465508 Mocket .reset ()
466509 if pyopenssl_override :
467510 # Put the pyopenssl version back in place
0 commit comments