Skip to content

Commit b167668

Browse files
[async] Applied #2451 to async code - test passing, ProxySessionManager for async with SessionWithProxy
1 parent c6153e6 commit b167668

15 files changed

+499
-60
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Error,
2222
OperationalError,
2323
ProgrammingError,
24-
proxy,
2524
)
2625

2726
from .._query_context_cache import QueryContextCache
@@ -80,6 +79,7 @@
8079
from ._session_manager import (
8180
AioHttpConfig,
8281
SessionManager,
82+
SessionManagerFactory,
8383
SnowflakeSSLConnectorFactory,
8484
)
8585
from ._telemetry import TelemetryClient
@@ -191,10 +191,6 @@ async def __open_connection(self):
191191
use_numpy=self._numpy, support_negative_year=self._support_negative_year
192192
)
193193

194-
proxy.set_proxies(
195-
self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password
196-
)
197-
198194
self._rest = SnowflakeRestful(
199195
host=self.host,
200196
port=self.port,
@@ -1014,13 +1010,17 @@ async def connect(self, **kwargs) -> None:
10141010
else:
10151011
self.__config(**self._conn_parameters)
10161012

1017-
self._http_config = AioHttpConfig(
1013+
self._http_config: AioHttpConfig = AioHttpConfig(
10181014
connector_factory=SnowflakeSSLConnectorFactory(),
10191015
use_pooling=not self.disable_request_pooling,
1016+
proxy_host=self.proxy_host,
1017+
proxy_port=self.proxy_port,
1018+
proxy_user=self.proxy_user,
1019+
proxy_password=self.proxy_password,
10201020
snowflake_ocsp_mode=self._ocsp_mode(),
10211021
trust_env=True, # Required for proxy support via environment variables
10221022
)
1023-
self._session_manager = SessionManager(self._http_config)
1023+
self._session_manager = SessionManagerFactory.get_manager(self._http_config)
10241024

10251025
if self.enable_connection_diag:
10261026
raise NotImplementedError(

src/snowflake/connector/aio/_network.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import TYPE_CHECKING, Any, AsyncGenerator
1111

1212
import OpenSSL.SSL
13-
from urllib3.util.url import parse_url
1413

1514
from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse, urlsplit
1615
from ..constants import (
@@ -79,7 +78,11 @@
7978
)
8079
from ..time_util import TimeoutBackoffCtx
8180
from ._description import CLIENT_NAME
82-
from ._session_manager import SessionManager, SnowflakeSSLConnectorFactory
81+
from ._session_manager import (
82+
SessionManager,
83+
SessionManagerFactory,
84+
SnowflakeSSLConnectorFactory,
85+
)
8386

8487
if TYPE_CHECKING:
8588
from snowflake.connector.aio import SnowflakeConnection
@@ -145,15 +148,12 @@ def __init__(
145148
session_manager = (
146149
connection._session_manager
147150
if (connection and connection._session_manager)
148-
else SessionManager(connector_factory=SnowflakeSSLConnectorFactory())
151+
else SessionManagerFactory.get_manager(
152+
connector_factory=SnowflakeSSLConnectorFactory()
153+
)
149154
)
150155
self._session_manager = session_manager
151156

152-
if self._connection and self._connection.proxy_host:
153-
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
154-
else:
155-
self._get_proxy_headers = lambda _: None
156-
157157
async def close(self) -> None:
158158
if hasattr(self, "_token"):
159159
del self._token
@@ -737,7 +737,6 @@ async def _request_exec(
737737
headers=headers,
738738
data=input_data,
739739
timeout=aiohttp.ClientTimeout(socket_timeout),
740-
proxy_headers=self._get_proxy_headers(full_url),
741740
)
742741
try:
743742
if raw_ret.status == OK:

src/snowflake/connector/aio/_result_batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
raise_failed_request_error,
1414
raise_okta_unauthorized_error,
1515
)
16-
from snowflake.connector.aio._session_manager import SessionManager
16+
from snowflake.connector.aio._session_manager import SessionManagerFactory
1717
from snowflake.connector.aio._time_util import TimerContextManager
1818
from snowflake.connector.arrow_context import ArrowConverterContext
1919
from snowflake.connector.backoff_policies import exponential_backoff
@@ -261,7 +261,9 @@ async def download_chunk(http_session):
261261
logger.debug(
262262
f"downloading result batch id: {self.id} with new session through local session manager"
263263
)
264-
local_session_manager = SessionManager(use_pooling=False)
264+
local_session_manager = SessionManagerFactory.get_manager(
265+
use_pooling=False
266+
)
265267
async with local_session_manager.use_session() as session:
266268
response, content, encoding = await download_chunk(session)
267269

src/snowflake/connector/aio/_session_manager.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import sys
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Unpack
55

66
from aiohttp import ClientRequest, ClientTimeout
7+
from aiohttp.client import _RequestOptions
78
from aiohttp.client_proto import ResponseHandler
89
from aiohttp.connector import Connection
10+
from aiohttp.typedefs import StrOrURL
911

1012
from .. import OperationalError
1113
from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED
@@ -44,10 +46,10 @@ def __init__(
4446
):
4547
self._snowflake_ocsp_mode = snowflake_ocsp_mode
4648
if session_manager is None:
47-
logger.debug(
48-
"SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance so please verify why it isn't true in the current context"
49+
logger.warning(
50+
"SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance - verify why it isn't true in the current context"
4951
)
50-
session_manager = SessionManager()
52+
session_manager = SessionManagerFactory.get_manager()
5153
self._session_manager = session_manager
5254
if self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN and sys.version_info < (
5355
3,
@@ -345,24 +347,38 @@ def __init__(
345347
lambda: SessionPool(self)
346348
)
347349

350+
@classmethod
351+
def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager:
352+
"""Build a new manager from *cfg*, optionally overriding fields.
353+
354+
Example::
355+
356+
no_pool_cfg = conn._http_config.copy_with(use_pooling=False)
357+
manager = SessionManager.from_config(no_pool_cfg)
358+
"""
359+
360+
if overrides:
361+
cfg = cfg.copy_with(**overrides)
362+
return cls(config=cfg)
363+
348364
@property
349365
def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
350366
return self._cfg.connector_factory
351367

352368
@connector_factory.setter
353369
def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
354-
self._cfg = self._cfg.copy_with(connector_factory=value)
370+
self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value)
355371

356372
def make_session(self) -> aiohttp.ClientSession:
357373
"""Create a new aiohttp.ClientSession with configured connector."""
358374
connector = self._cfg.get_connector(
359375
session_manager=self.clone(),
360376
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
361377
)
362-
363378
return aiohttp.ClientSession(
364379
connector=connector,
365380
trust_env=self._cfg.trust_env,
381+
proxy=self.proxy_url,
366382
)
367383

368384
@contextlib.asynccontextmanager
@@ -425,7 +441,7 @@ def clone(
425441
if connector_factory is not None:
426442
overrides["connector_factory"] = connector_factory
427443

428-
return SessionManager.from_config(self._cfg, **overrides)
444+
return self.from_config(self._cfg, **overrides)
429445

430446

431447
async def request(
@@ -454,3 +470,73 @@ async def request(
454470
use_pooling=use_pooling,
455471
**kwargs,
456472
)
473+
474+
475+
class ProxySessionManager(SessionManager):
476+
class SessionWithProxy(aiohttp.ClientSession):
477+
async def request(
478+
self,
479+
method: str,
480+
url: StrOrURL,
481+
**kwargs: Unpack[_RequestOptions],
482+
):
483+
# Inject Host header when proxying
484+
try:
485+
# respect caller-provided proxy and proxy_headers if any
486+
provided_proxy = kwargs.get("proxy") or self._default_proxy
487+
provided_proxy_headers = kwargs.get("proxy_headers")
488+
if provided_proxy is not None:
489+
authority = urlparse(str(url)).netloc
490+
if provided_proxy_headers is None:
491+
kwargs["proxy_headers"] = {"Host": authority}
492+
elif "Host" not in provided_proxy_headers:
493+
provided_proxy_headers["Host"] = authority
494+
else:
495+
logger.debug(
496+
"Host header was already set - not overriding with netloc at the ClientSession.request method level."
497+
)
498+
except Exception:
499+
logger.warning(
500+
"Failed to compute proxy settings for %s",
501+
urlparse(url).hostname,
502+
exc_info=True,
503+
)
504+
return await super().request(method, url, **kwargs)
505+
506+
def make_session(self) -> aiohttp.ClientSession:
507+
connector = self._cfg.get_connector(
508+
session_manager=self.clone(),
509+
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
510+
)
511+
# Construct session with base proxy set, request() may override per-URL when bypassing
512+
return self.SessionWithProxy(
513+
connector=connector,
514+
trust_env=self._cfg.trust_env,
515+
proxy=self.proxy_url,
516+
)
517+
518+
519+
class SessionManagerFactory:
520+
@staticmethod
521+
def get_manager(
522+
config: AioHttpConfig | None = None, **http_config_kwargs
523+
) -> SessionManager:
524+
"""Return a proxy-aware or plain async SessionManager based on config.
525+
526+
If any explicit proxy parameters are provided (in config or kwargs),
527+
return ProxySessionManager; otherwise return the base SessionManager.
528+
"""
529+
530+
def _has_proxy_params(cfg: AioHttpConfig | None, kwargs: dict) -> bool:
531+
cfg_keys = (
532+
"proxy_host",
533+
"proxy_port",
534+
)
535+
in_cfg = any(getattr(cfg, k, None) for k in cfg_keys) if cfg else False
536+
in_kwargs = "proxy" in kwargs
537+
return in_cfg or in_kwargs
538+
539+
if _has_proxy_params(config, http_config_kwargs):
540+
return ProxySessionManager(config, **http_config_kwargs)
541+
else:
542+
return SessionManager(config, **http_config_kwargs)

src/snowflake/connector/aio/_storage_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..encryption_util import SnowflakeEncryptionUtil
1616
from ..errors import RequestExceedMaxRetryError
1717
from ..storage_client import SnowflakeStorageClient as SnowflakeStorageClientSync
18-
from ._session_manager import SessionManager
18+
from ._session_manager import SessionManagerFactory
1919

2020
if TYPE_CHECKING: # pragma: no cover
2121
from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential
@@ -205,7 +205,9 @@ async def _send_request_with_retry(
205205
# SessionManager on the fly, if code ends up here, since we probably do not care about losing
206206
# proxy or HTTP setup.
207207
logger.debug("storage client request with new session")
208-
session_manager = SessionManager(use_pooling=False)
208+
session_manager = SessionManagerFactory.get_manager(
209+
use_pooling=False
210+
)
209211
response = await session_manager.request(verb, url, **rest_kwargs)
210212

211213
if await self._has_expired_presigned_url(response):

src/snowflake/connector/aio/_wif_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
extract_iss_and_sub_without_signature_verification,
2222
get_aws_sts_hostname,
2323
)
24-
from ._session_manager import SessionManager
24+
from ._session_manager import SessionManager, SessionManagerFactory
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -187,7 +187,9 @@ async def create_attestation(
187187
"""
188188
entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
189189
session_manager = (
190-
session_manager.clone() if session_manager else SessionManager(use_pooling=True)
190+
session_manager.clone()
191+
if session_manager
192+
else SessionManagerFactory.get_manager(use_pooling=True)
191193
)
192194

193195
if provider == AttestationProvider.AWS:

src/snowflake/connector/result_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def __init__(
261261
[s._to_result_metadata_v1() for s in schema] if schema is not None else None
262262
)
263263
self._use_dict_result = use_dict_result
264-
# Passed to contain the configured Http behavior in case the connectio is no longer active for the download
264+
# Passed to contain the configured Http behavior in case the connection is no longer active for the download
265265
# Can be overridden with setters if needed.
266266
self._session_manager = session_manager
267267
self._metrics: dict[str, int] = {}

test/data/wiremock/mappings/auth/password/successful_flow.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
},
1919
"response": {
2020
"status": 200,
21+
"headers": { "Content-Type": "application/json" },
2122
"jsonBody": {
2223
"data": {
2324
"masterToken": "master token",

test/data/wiremock/mappings/queries/select_1_successful.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
},
1212
"response": {
1313
"status": 200,
14+
"headers": { "Content-Type": "application/json" },
1415
"jsonBody": {
1516
"data": {
1617
"parameters": [

test/data/wiremock/mappings/queries/select_large_request_successful.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
},
1212
"response": {
1313
"status": 200,
14+
"headers": { "Content-Type": "application/json" },
1415
"jsonBody": {
1516
"data": {
1617
"parameters": [

0 commit comments

Comments
 (0)