diff --git a/src/requests/adapters.py b/src/requests/adapters.py index 9a58b16025..b4cc4e6505 100644 --- a/src/requests/adapters.py +++ b/src/requests/adapters.py @@ -472,6 +472,17 @@ def get_connection_with_tls_context(self, request, verify, proxies=None, cert=No ) except ValueError as e: raise InvalidURL(e, request=request) + + # If request has a Host header, use its value as the server_hostname for SNI + server_hostname = None + if 'Host' in request.headers: + server_hostname = request.headers['Host'] + elif 'host' in host_params: + server_hostname = host_params['host'] + + if server_hostname: + pool_kwargs['server_hostname'] = server_hostname + if proxy: proxy = prepend_scheme_if_needed(proxy, "http") proxy_url = parse_url(proxy) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 6c55d5a130..080c8075a6 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -1,4 +1,8 @@ +import unittest +from unittest import mock + import requests.adapters +from requests.models import PreparedRequest def test_request_url_trims_leading_path_separators(): @@ -6,3 +10,61 @@ def test_request_url_trims_leading_path_separators(): a = requests.adapters.HTTPAdapter() p = requests.Request(method="GET", url="http://127.0.0.1:10000//v:h").prepare() assert "/v:h" == a.request_url(p, {}) + + +class TestHTTPAdapter(unittest.TestCase): + def test_get_connection_with_tls_context(self): + """Test that get_connection_with_tls_context correctly passes server_hostname.""" + adapter = requests.adapters.HTTPAdapter() + + # Case 1: Request with a custom 'Host' header (custom SNI) + mock_request = mock.MagicMock(spec=PreparedRequest) + mock_request.url = "https://1.2.3.4:5678/api/check" + mock_request.headers = {'Host': 'custom.hostname.com'} + + # Mock the pool managers + adapter.poolmanager = mock.MagicMock() + adapter.proxy_manager = {} + + # Test without proxy + adapter.get_connection_with_tls_context(mock_request, verify=True) + + # Verify that poolmanager.connection_from_host was called with server_hostname + called_args = adapter.poolmanager.connection_from_host.call_args[1] + self.assertIn('pool_kwargs', called_args) + self.assertIn('server_hostname', called_args['pool_kwargs']) + self.assertEqual(called_args['pool_kwargs']['server_hostname'], 'custom.hostname.com') + + # Case 2: Request without a 'Host' header (default hostname) + mock_request = mock.MagicMock(spec=PreparedRequest) + mock_request.url = "https://example.com:443/api/check" + mock_request.headers = {} + + adapter.get_connection_with_tls_context(mock_request, verify=True) + + # Verify that poolmanager.connection_from_host was called with the correct server_hostname + called_args = adapter.poolmanager.connection_from_host.call_args[1] + self.assertIn('pool_kwargs', called_args) + self.assertIn('server_hostname', called_args['pool_kwargs']) + self.assertEqual(called_args['pool_kwargs']['server_hostname'], 'example.com') + + # Case 3: Request with proxy and custom 'Host' header + mock_request = mock.MagicMock(spec=PreparedRequest) + mock_request.url = "https://1.2.3.4:5678/api/check" + mock_request.headers = {'Host': 'custom.hostname.com'} + + # Setup proxy + proxy = "http://127.0.0.1:8080" + proxies = {"https": proxy} + + # Mock the proxy manager + mock_proxy_manager = mock.MagicMock() + adapter.proxy_manager_for = mock.MagicMock(return_value=mock_proxy_manager) + + adapter.get_connection_with_tls_context(mock_request, verify=True, proxies=proxies) + + # Verify that proxy_manager.connection_from_host was called with server_hostname + called_args = mock_proxy_manager.connection_from_host.call_args[1] + self.assertIn('pool_kwargs', called_args) + self.assertIn('server_hostname', called_args['pool_kwargs']) + self.assertEqual(called_args['pool_kwargs']['server_hostname'], 'custom.hostname.com')