Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/2596.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed proxy authorization headers not being passed when reusing a connection, which caused 407 (Proxy authentication required) errors
-- by :user:`GLeurquin`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ Georges Dubus
Greg Holt
Gregory Haynes
Grigoriy Soldatov
Guillaume Leurquin
Gus Goulart
Gustavo Carneiro
Günther Jena
Expand Down
49 changes: 29 additions & 20 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,30 @@ def _available_connections(self, key: "ConnectionKey") -> int:

return total_remain

def _update_proxy_auth_header_and_build_proxy_req(
self, req: ClientRequest
) -> ClientRequestBase:
"""Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests."""
url = req.proxy
assert url is not None
headers = req.proxy_headers or CIMultiDict[str]()
headers[hdrs.HOST] = req.headers[hdrs.HOST]
proxy_req = ClientRequestBase(
hdrs.METH_GET,
url,
headers=headers,
auth=req.proxy_auth,
loop=self._loop,
ssl=req.ssl,
)
auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
if auth is not None:
if not req.is_ssl():
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
else:
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
return proxy_req

async def connect(
self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
) -> Connection:
Expand All @@ -558,12 +582,16 @@ async def connect(
if (conn := await self._get(key, traces)) is not None:
# If we do not have to wait and we can get a connection from the pool
# we can avoid the timeout ceil logic and directly return the connection
if req.proxy:
self._update_proxy_auth_header_and_build_proxy_req(req)
return conn

async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
if self._available_connections(key) <= 0:
await self._wait_for_available_connection(key, traces)
if (conn := await self._get(key, traces)) is not None:
if req.proxy:
self._update_proxy_auth_header_and_build_proxy_req(req)
return conn

placeholder = cast(
Expand Down Expand Up @@ -1453,32 +1481,13 @@ async def _create_direct_connection(
async def _create_proxy_connection(
self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
) -> tuple[asyncio.BaseTransport, ResponseHandler]:
headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers
headers[hdrs.HOST] = req.headers[hdrs.HOST]

url = req.proxy
assert url is not None
proxy_req = ClientRequestBase(
hdrs.METH_GET,
url,
headers=headers,
auth=req.proxy_auth,
loop=self._loop,
ssl=req.ssl,
)
proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req)

# create connection to proxy server
transport, proto = await self._create_direct_connection(
proxy_req, [], timeout, client_error=ClientProxyConnectionError
)

auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
if auth is not None:
if not req.is_ssl():
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
else:
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth

if req.is_ssl():
self._warn_about_tls_in_tls(transport, req)

Expand Down
90 changes: 90 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tests of http client with custom Connector
import asyncio
import contextlib
import gc
import hashlib
import platform
Expand All @@ -16,6 +17,7 @@
from unittest import mock

import pytest
from multidict import CIMultiDict
from pytest_mock import MockerFixture
from yarl import URL

Expand All @@ -25,6 +27,7 @@
ClientSession,
ClientTimeout,
connector as connector_module,
hdrs,
web,
)
from aiohttp.abc import ResolveResult
Expand Down Expand Up @@ -3299,6 +3302,93 @@ async def test_connect_reuseconn_tracing(
await conn.close()


@pytest.mark.parametrize(
"test_case,wait_for_con,expect_proxy_auth_header",
[
("use_proxy_with_embedded_auth", False, True),
("use_proxy_with_auth_headers", True, True),
("use_proxy_no_auth", False, False),
("dont_use_proxy", False, False),
],
)
async def test_connect_reuse_proxy_headers( # type: ignore[misc]
loop: asyncio.AbstractEventLoop,
make_client_request: _RequestMaker,
test_case: str,
wait_for_con: bool,
expect_proxy_auth_header: bool,
) -> None:
proto = create_mocked_conn(loop)
proto.is_connected.return_value = True

if test_case != "dont_use_proxy":
proxy = (
URL("http://user:[email protected]")
if test_case == "use_proxy_with_embedded_auth"
else URL("http://example.com")
)
proxy_headers = (
CIMultiDict({hdrs.AUTHORIZATION: "Basic dXNlcjpwYXNzd29yZA=="})
if test_case == "use_proxy_with_auth_headers"
else None
)
else:
proxy = None
proxy_headers = None
key = ConnectionKey(
"localhost",
80,
False,
True,
proxy,
None,
hash(tuple(proxy_headers.items())) if proxy_headers else None,
)
req = make_client_request(
"GET",
URL("http://localhost:80"),
loop=loop,
response_class=mock.Mock(),
proxy=proxy,
proxy_headers=proxy_headers,
)

conn = aiohttp.BaseConnector(limit=1)

async def _create_con(*args: Any, **kwargs: Any) -> None:
conn._conns[key] = deque([(proto, loop.time())])

with contextlib.ExitStack() as stack:
if wait_for_con:
# Simulate no available connections
stack.enter_context(
mock.patch.object(
conn, "_available_connections", autospec=True, return_value=0
)
)
# Upon waiting for a connection, populate _conns with our proto,
# mocking a connection becoming immediately available
stack.enter_context(
mock.patch.object(
conn,
"_wait_for_available_connection",
autospec=True,
side_effect=_create_con,
)
)
else:
await _create_con()
# Call function to test
conn2 = await conn.connect(req, [], ClientTimeout())
conn2.release()
await conn.close()

if expect_proxy_auth_header:
assert req.headers[hdrs.PROXY_AUTHORIZATION] == "Basic dXNlcjpwYXNzd29yZA=="
else:
assert hdrs.PROXY_AUTHORIZATION not in req.headers


async def test_connect_with_limit_and_limit_per_host(
loop: asyncio.AbstractEventLoop,
key: ConnectionKey,
Expand Down
Loading