Skip to content

Commit 647f3af

Browse files
[async] Applied #2451 to async code - test passing, ProxySessionManager for async with SessionWithProxy
1 parent c6153e6 commit 647f3af

15 files changed

+509
-59
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Error,
2222
OperationalError,
2323
ProgrammingError,
24-
proxy,
2524
)
2625

2726
from .._query_context_cache import QueryContextCache
@@ -80,6 +79,7 @@
8079
from ._session_manager import (
8180
AioHttpConfig,
8281
SessionManager,
82+
SessionManagerFactory,
8383
SnowflakeSSLConnectorFactory,
8484
)
8585
from ._telemetry import TelemetryClient
@@ -191,10 +191,6 @@ async def __open_connection(self):
191191
use_numpy=self._numpy, support_negative_year=self._support_negative_year
192192
)
193193

194-
proxy.set_proxies(
195-
self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password
196-
)
197-
198194
self._rest = SnowflakeRestful(
199195
host=self.host,
200196
port=self.port,
@@ -1014,13 +1010,17 @@ async def connect(self, **kwargs) -> None:
10141010
else:
10151011
self.__config(**self._conn_parameters)
10161012

1017-
self._http_config = AioHttpConfig(
1013+
self._http_config: AioHttpConfig = AioHttpConfig(
10181014
connector_factory=SnowflakeSSLConnectorFactory(),
10191015
use_pooling=not self.disable_request_pooling,
1016+
proxy_host=self.proxy_host,
1017+
proxy_port=self.proxy_port,
1018+
proxy_user=self.proxy_user,
1019+
proxy_password=self.proxy_password,
10201020
snowflake_ocsp_mode=self._ocsp_mode(),
10211021
trust_env=True, # Required for proxy support via environment variables
10221022
)
1023-
self._session_manager = SessionManager(self._http_config)
1023+
self._session_manager = SessionManagerFactory.get_manager(self._http_config)
10241024

10251025
if self.enable_connection_diag:
10261026
raise NotImplementedError(

src/snowflake/connector/aio/_network.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import TYPE_CHECKING, Any, AsyncGenerator
1111

1212
import OpenSSL.SSL
13-
from urllib3.util.url import parse_url
1413

1514
from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse, urlsplit
1615
from ..constants import (
@@ -79,7 +78,11 @@
7978
)
8079
from ..time_util import TimeoutBackoffCtx
8180
from ._description import CLIENT_NAME
82-
from ._session_manager import SessionManager, SnowflakeSSLConnectorFactory
81+
from ._session_manager import (
82+
SessionManager,
83+
SessionManagerFactory,
84+
SnowflakeSSLConnectorFactory,
85+
)
8386

8487
if TYPE_CHECKING:
8588
from snowflake.connector.aio import SnowflakeConnection
@@ -145,15 +148,12 @@ def __init__(
145148
session_manager = (
146149
connection._session_manager
147150
if (connection and connection._session_manager)
148-
else SessionManager(connector_factory=SnowflakeSSLConnectorFactory())
151+
else SessionManagerFactory.get_manager(
152+
connector_factory=SnowflakeSSLConnectorFactory()
153+
)
149154
)
150155
self._session_manager = session_manager
151156

152-
if self._connection and self._connection.proxy_host:
153-
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
154-
else:
155-
self._get_proxy_headers = lambda _: None
156-
157157
async def close(self) -> None:
158158
if hasattr(self, "_token"):
159159
del self._token
@@ -737,7 +737,6 @@ async def _request_exec(
737737
headers=headers,
738738
data=input_data,
739739
timeout=aiohttp.ClientTimeout(socket_timeout),
740-
proxy_headers=self._get_proxy_headers(full_url),
741740
)
742741
try:
743742
if raw_ret.status == OK:

src/snowflake/connector/aio/_result_batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
raise_failed_request_error,
1414
raise_okta_unauthorized_error,
1515
)
16-
from snowflake.connector.aio._session_manager import SessionManager
16+
from snowflake.connector.aio._session_manager import SessionManagerFactory
1717
from snowflake.connector.aio._time_util import TimerContextManager
1818
from snowflake.connector.arrow_context import ArrowConverterContext
1919
from snowflake.connector.backoff_policies import exponential_backoff
@@ -261,7 +261,9 @@ async def download_chunk(http_session):
261261
logger.debug(
262262
f"downloading result batch id: {self.id} with new session through local session manager"
263263
)
264-
local_session_manager = SessionManager(use_pooling=False)
264+
local_session_manager = SessionManagerFactory.get_manager(
265+
use_pooling=False
266+
)
265267
async with local_session_manager.use_session() as session:
266268
response, content, encoding = await download_chunk(session)
267269

src/snowflake/connector/aio/_session_manager.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from typing import TYPE_CHECKING
55

66
from aiohttp import ClientRequest, ClientTimeout
7+
from aiohttp.client import _RequestOptions
78
from aiohttp.client_proto import ResponseHandler
89
from aiohttp.connector import Connection
10+
from aiohttp.typedefs import StrOrURL
911

1012
from .. import OperationalError
1113
from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED
@@ -14,6 +16,8 @@
1416

1517
if TYPE_CHECKING:
1618
from aiohttp.tracing import Trace
19+
from typing import Unpack
20+
from aiohttp.client import _RequestContextManager
1721

1822
import abc
1923
import collections
@@ -44,10 +48,10 @@ def __init__(
4448
):
4549
self._snowflake_ocsp_mode = snowflake_ocsp_mode
4650
if session_manager is None:
47-
logger.debug(
48-
"SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance so please verify why it isn't true in the current context"
51+
logger.warning(
52+
"SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance - verify why it isn't true in the current context"
4953
)
50-
session_manager = SessionManager()
54+
session_manager = SessionManagerFactory.get_manager()
5155
self._session_manager = session_manager
5256
if self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN and sys.version_info < (
5357
3,
@@ -345,24 +349,38 @@ def __init__(
345349
lambda: SessionPool(self)
346350
)
347351

352+
@classmethod
353+
def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager:
354+
"""Build a new manager from *cfg*, optionally overriding fields.
355+
356+
Example::
357+
358+
no_pool_cfg = conn._http_config.copy_with(use_pooling=False)
359+
manager = SessionManager.from_config(no_pool_cfg)
360+
"""
361+
362+
if overrides:
363+
cfg = cfg.copy_with(**overrides)
364+
return cls(config=cfg)
365+
348366
@property
349367
def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
350368
return self._cfg.connector_factory
351369

352370
@connector_factory.setter
353371
def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
354-
self._cfg = self._cfg.copy_with(connector_factory=value)
372+
self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value)
355373

356374
def make_session(self) -> aiohttp.ClientSession:
357375
"""Create a new aiohttp.ClientSession with configured connector."""
358376
connector = self._cfg.get_connector(
359377
session_manager=self.clone(),
360378
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
361379
)
362-
363380
return aiohttp.ClientSession(
364381
connector=connector,
365382
trust_env=self._cfg.trust_env,
383+
proxy=self.proxy_url,
366384
)
367385

368386
@contextlib.asynccontextmanager
@@ -425,7 +443,7 @@ def clone(
425443
if connector_factory is not None:
426444
overrides["connector_factory"] = connector_factory
427445

428-
return SessionManager.from_config(self._cfg, **overrides)
446+
return self.from_config(self._cfg, **overrides)
429447

430448

431449
async def request(
@@ -454,3 +472,82 @@ async def request(
454472
use_pooling=use_pooling,
455473
**kwargs,
456474
)
475+
476+
477+
class ProxySessionManager(SessionManager):
478+
class SessionWithProxy(aiohttp.ClientSession):
479+
if sys.version_info >= (3, 11) and TYPE_CHECKING:
480+
481+
def request(
482+
self,
483+
method: str,
484+
url: StrOrURL,
485+
**kwargs: Unpack[_RequestOptions],
486+
) -> _RequestContextManager: ...
487+
488+
else:
489+
490+
def request(
491+
self, method: str, url: StrOrURL, **kwargs: Any
492+
) -> _RequestContextManager:
493+
"""Perform HTTP request."""
494+
# Inject Host header when proxying
495+
try:
496+
# respect caller-provided proxy and proxy_headers if any
497+
provided_proxy = kwargs.get("proxy") or self._default_proxy
498+
provided_proxy_headers = kwargs.get("proxy_headers")
499+
if provided_proxy is not None:
500+
authority = urlparse(str(url)).netloc
501+
if provided_proxy_headers is None:
502+
kwargs["proxy_headers"] = {"Host": authority}
503+
elif "Host" not in provided_proxy_headers:
504+
provided_proxy_headers["Host"] = authority
505+
else:
506+
logger.debug(
507+
"Host header was already set - not overriding with netloc at the ClientSession.request method level."
508+
)
509+
except Exception:
510+
logger.warning(
511+
"Failed to compute proxy settings for %s",
512+
urlparse(url).hostname,
513+
exc_info=True,
514+
)
515+
return super().request(method, url, **kwargs)
516+
517+
def make_session(self) -> aiohttp.ClientSession:
518+
connector = self._cfg.get_connector(
519+
session_manager=self.clone(),
520+
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
521+
)
522+
# Construct session with base proxy set, request() may override per-URL when bypassing
523+
return self.SessionWithProxy(
524+
connector=connector,
525+
trust_env=self._cfg.trust_env,
526+
proxy=self.proxy_url,
527+
)
528+
529+
530+
class SessionManagerFactory:
531+
@staticmethod
532+
def get_manager(
533+
config: AioHttpConfig | None = None, **http_config_kwargs
534+
) -> SessionManager:
535+
"""Return a proxy-aware or plain async SessionManager based on config.
536+
537+
If any explicit proxy parameters are provided (in config or kwargs),
538+
return ProxySessionManager; otherwise return the base SessionManager.
539+
"""
540+
541+
def _has_proxy_params(cfg: AioHttpConfig | None, kwargs: dict) -> bool:
542+
cfg_keys = (
543+
"proxy_host",
544+
"proxy_port",
545+
)
546+
in_cfg = any(getattr(cfg, k, None) for k in cfg_keys) if cfg else False
547+
in_kwargs = "proxy" in kwargs
548+
return in_cfg or in_kwargs
549+
550+
if _has_proxy_params(config, http_config_kwargs):
551+
return ProxySessionManager(config, **http_config_kwargs)
552+
else:
553+
return SessionManager(config, **http_config_kwargs)

src/snowflake/connector/aio/_storage_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..encryption_util import SnowflakeEncryptionUtil
1616
from ..errors import RequestExceedMaxRetryError
1717
from ..storage_client import SnowflakeStorageClient as SnowflakeStorageClientSync
18-
from ._session_manager import SessionManager
18+
from ._session_manager import SessionManagerFactory
1919

2020
if TYPE_CHECKING: # pragma: no cover
2121
from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential
@@ -205,7 +205,9 @@ async def _send_request_with_retry(
205205
# SessionManager on the fly, if code ends up here, since we probably do not care about losing
206206
# proxy or HTTP setup.
207207
logger.debug("storage client request with new session")
208-
session_manager = SessionManager(use_pooling=False)
208+
session_manager = SessionManagerFactory.get_manager(
209+
use_pooling=False
210+
)
209211
response = await session_manager.request(verb, url, **rest_kwargs)
210212

211213
if await self._has_expired_presigned_url(response):

src/snowflake/connector/aio/_wif_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
extract_iss_and_sub_without_signature_verification,
2222
get_aws_sts_hostname,
2323
)
24-
from ._session_manager import SessionManager
24+
from ._session_manager import SessionManager, SessionManagerFactory
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -187,7 +187,9 @@ async def create_attestation(
187187
"""
188188
entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
189189
session_manager = (
190-
session_manager.clone() if session_manager else SessionManager(use_pooling=True)
190+
session_manager.clone()
191+
if session_manager
192+
else SessionManagerFactory.get_manager(use_pooling=True)
191193
)
192194

193195
if provider == AttestationProvider.AWS:

src/snowflake/connector/result_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def __init__(
261261
[s._to_result_metadata_v1() for s in schema] if schema is not None else None
262262
)
263263
self._use_dict_result = use_dict_result
264-
# Passed to contain the configured Http behavior in case the connectio is no longer active for the download
264+
# Passed to contain the configured Http behavior in case the connection is no longer active for the download
265265
# Can be overridden with setters if needed.
266266
self._session_manager = session_manager
267267
self._metrics: dict[str, int] = {}

test/data/wiremock/mappings/auth/password/successful_flow.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
},
1919
"response": {
2020
"status": 200,
21+
"headers": { "Content-Type": "application/json" },
2122
"jsonBody": {
2223
"data": {
2324
"masterToken": "master token",

test/data/wiremock/mappings/queries/select_1_successful.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
},
1212
"response": {
1313
"status": 200,
14+
"headers": { "Content-Type": "application/json" },
1415
"jsonBody": {
1516
"data": {
1617
"parameters": [

test/data/wiremock/mappings/queries/select_large_request_successful.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
},
1212
"response": {
1313
"status": 200,
14+
"headers": { "Content-Type": "application/json" },
1415
"jsonBody": {
1516
"data": {
1617
"parameters": [

0 commit comments

Comments
 (0)