|  | 
| 1 | 1 | from __future__ import annotations | 
| 2 | 2 | 
 | 
| 3 | 3 | import sys | 
| 4 |  | -from typing import TYPE_CHECKING | 
|  | 4 | +from typing import TYPE_CHECKING, Unpack | 
| 5 | 5 | 
 | 
| 6 | 6 | from aiohttp import ClientRequest, ClientTimeout | 
|  | 7 | +from aiohttp.client import _RequestOptions | 
| 7 | 8 | from aiohttp.client_proto import ResponseHandler | 
| 8 | 9 | from aiohttp.connector import Connection | 
|  | 10 | +from aiohttp.typedefs import StrOrURL | 
| 9 | 11 | 
 | 
| 10 | 12 | from .. import OperationalError | 
| 11 | 13 | from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED | 
| @@ -44,10 +46,10 @@ def __init__( | 
| 44 | 46 |     ): | 
| 45 | 47 |         self._snowflake_ocsp_mode = snowflake_ocsp_mode | 
| 46 | 48 |         if session_manager is None: | 
| 47 |  | -            logger.debug( | 
| 48 |  | -                "SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance so please verify why it isn't true in the current context" | 
|  | 49 | +            logger.warning( | 
|  | 50 | +                "SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance - verify why it isn't true in the current context" | 
| 49 | 51 |             ) | 
| 50 |  | -            session_manager = SessionManager() | 
|  | 52 | +            session_manager = SessionManagerFactory.get_manager() | 
| 51 | 53 |         self._session_manager = session_manager | 
| 52 | 54 |         if self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN and sys.version_info < ( | 
| 53 | 55 |             3, | 
| @@ -345,24 +347,38 @@ def __init__( | 
| 345 | 347 |             lambda: SessionPool(self) | 
| 346 | 348 |         ) | 
| 347 | 349 | 
 | 
|  | 350 | +    @classmethod | 
|  | 351 | +    def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager: | 
|  | 352 | +        """Build a new manager from *cfg*, optionally overriding fields. | 
|  | 353 | +
 | 
|  | 354 | +        Example:: | 
|  | 355 | +
 | 
|  | 356 | +            no_pool_cfg = conn._http_config.copy_with(use_pooling=False) | 
|  | 357 | +            manager = SessionManager.from_config(no_pool_cfg) | 
|  | 358 | +        """ | 
|  | 359 | + | 
|  | 360 | +        if overrides: | 
|  | 361 | +            cfg = cfg.copy_with(**overrides) | 
|  | 362 | +        return cls(config=cfg) | 
|  | 363 | + | 
| 348 | 364 |     @property | 
| 349 | 365 |     def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]: | 
| 350 | 366 |         return self._cfg.connector_factory | 
| 351 | 367 | 
 | 
| 352 | 368 |     @connector_factory.setter | 
| 353 | 369 |     def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None: | 
| 354 |  | -        self._cfg = self._cfg.copy_with(connector_factory=value) | 
|  | 370 | +        self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value) | 
| 355 | 371 | 
 | 
| 356 | 372 |     def make_session(self) -> aiohttp.ClientSession: | 
| 357 | 373 |         """Create a new aiohttp.ClientSession with configured connector.""" | 
| 358 | 374 |         connector = self._cfg.get_connector( | 
| 359 | 375 |             session_manager=self.clone(), | 
| 360 | 376 |             snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode, | 
| 361 | 377 |         ) | 
| 362 |  | - | 
| 363 | 378 |         return aiohttp.ClientSession( | 
| 364 | 379 |             connector=connector, | 
| 365 | 380 |             trust_env=self._cfg.trust_env, | 
|  | 381 | +            proxy=self.proxy_url, | 
| 366 | 382 |         ) | 
| 367 | 383 | 
 | 
| 368 | 384 |     @contextlib.asynccontextmanager | 
| @@ -425,7 +441,7 @@ def clone( | 
| 425 | 441 |         if connector_factory is not None: | 
| 426 | 442 |             overrides["connector_factory"] = connector_factory | 
| 427 | 443 | 
 | 
| 428 |  | -        return SessionManager.from_config(self._cfg, **overrides) | 
|  | 444 | +        return self.from_config(self._cfg, **overrides) | 
| 429 | 445 | 
 | 
| 430 | 446 | 
 | 
| 431 | 447 | async def request( | 
| @@ -454,3 +470,73 @@ async def request( | 
| 454 | 470 |         use_pooling=use_pooling, | 
| 455 | 471 |         **kwargs, | 
| 456 | 472 |     ) | 
|  | 473 | + | 
|  | 474 | + | 
|  | 475 | +class ProxySessionManager(SessionManager): | 
|  | 476 | +    class SessionWithProxy(aiohttp.ClientSession): | 
|  | 477 | +        async def request( | 
|  | 478 | +            self, | 
|  | 479 | +            method: str, | 
|  | 480 | +            url: StrOrURL, | 
|  | 481 | +            **kwargs: Unpack[_RequestOptions], | 
|  | 482 | +        ): | 
|  | 483 | +            # Inject Host header when proxying | 
|  | 484 | +            try: | 
|  | 485 | +                # respect caller-provided proxy and proxy_headers if any | 
|  | 486 | +                provided_proxy = kwargs.get("proxy") or self._default_proxy | 
|  | 487 | +                provided_proxy_headers = kwargs.get("proxy_headers") | 
|  | 488 | +                if provided_proxy is not None: | 
|  | 489 | +                    authority = urlparse(str(url)).netloc | 
|  | 490 | +                    if provided_proxy_headers is None: | 
|  | 491 | +                        kwargs["proxy_headers"] = {"Host": authority} | 
|  | 492 | +                    elif "Host" not in provided_proxy_headers: | 
|  | 493 | +                        provided_proxy_headers["Host"] = authority | 
|  | 494 | +                    else: | 
|  | 495 | +                        logger.debug( | 
|  | 496 | +                            "Host header was already set - not overriding with netloc at the ClientSession.request method level." | 
|  | 497 | +                        ) | 
|  | 498 | +            except Exception: | 
|  | 499 | +                logger.warning( | 
|  | 500 | +                    "Failed to compute proxy settings for %s", | 
|  | 501 | +                    urlparse(url).hostname, | 
|  | 502 | +                    exc_info=True, | 
|  | 503 | +                ) | 
|  | 504 | +            return await super().request(method, url, **kwargs) | 
|  | 505 | + | 
|  | 506 | +    def make_session(self) -> aiohttp.ClientSession: | 
|  | 507 | +        connector = self._cfg.get_connector( | 
|  | 508 | +            session_manager=self.clone(), | 
|  | 509 | +            snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode, | 
|  | 510 | +        ) | 
|  | 511 | +        # Construct session with base proxy set, request() may override per-URL when bypassing | 
|  | 512 | +        return self.SessionWithProxy( | 
|  | 513 | +            connector=connector, | 
|  | 514 | +            trust_env=self._cfg.trust_env, | 
|  | 515 | +            proxy=self.proxy_url, | 
|  | 516 | +        ) | 
|  | 517 | + | 
|  | 518 | + | 
|  | 519 | +class SessionManagerFactory: | 
|  | 520 | +    @staticmethod | 
|  | 521 | +    def get_manager( | 
|  | 522 | +        config: AioHttpConfig | None = None, **http_config_kwargs | 
|  | 523 | +    ) -> SessionManager: | 
|  | 524 | +        """Return a proxy-aware or plain async SessionManager based on config. | 
|  | 525 | +
 | 
|  | 526 | +        If any explicit proxy parameters are provided (in config or kwargs), | 
|  | 527 | +        return ProxySessionManager; otherwise return the base SessionManager. | 
|  | 528 | +        """ | 
|  | 529 | + | 
|  | 530 | +        def _has_proxy_params(cfg: AioHttpConfig | None, kwargs: dict) -> bool: | 
|  | 531 | +            cfg_keys = ( | 
|  | 532 | +                "proxy_host", | 
|  | 533 | +                "proxy_port", | 
|  | 534 | +            ) | 
|  | 535 | +            in_cfg = any(getattr(cfg, k, None) for k in cfg_keys) if cfg else False | 
|  | 536 | +            in_kwargs = "proxy" in kwargs | 
|  | 537 | +            return in_cfg or in_kwargs | 
|  | 538 | + | 
|  | 539 | +        if _has_proxy_params(config, http_config_kwargs): | 
|  | 540 | +            return ProxySessionManager(config, **http_config_kwargs) | 
|  | 541 | +        else: | 
|  | 542 | +            return SessionManager(config, **http_config_kwargs) | 
0 commit comments