diff --git a/polygon/rest/base.py b/polygon/rest/base.py index d9d4768a..232af02e 100644 --- a/polygon/rest/base.py +++ b/polygon/rest/base.py @@ -2,9 +2,10 @@ import json import urllib3 import inspect +from urllib3 import PoolManager, ProxyManager from urllib3.util.retry import Retry from enum import Enum -from typing import Optional, Any, Dict +from typing import Optional, Any, Dict, Union from datetime import datetime from importlib.metadata import version, PackageNotFoundError from .models.request import RequestOptionBuilder @@ -22,6 +23,8 @@ class BaseClient: + client: Union[PoolManager, ProxyManager] + def __init__( self, api_key: Optional[str], @@ -33,6 +36,7 @@ def __init__( verbose: bool, trace: bool, custom_json: Optional[Any] = None, + proxy: Optional[str] = None, ): if api_key is None: raise AuthError( @@ -66,15 +70,25 @@ def __init__( backoff_factor=0.1, # [0.0s, 0.2s, 0.4s, 0.8s, 1.6s, ...] ) - # https://urllib3.readthedocs.io/en/stable/reference/urllib3.poolmanager.html - # https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool - self.client = urllib3.PoolManager( - num_pools=num_pools, - headers=self.headers, # default headers sent with each request. - ca_certs=certifi.where(), - cert_reqs="CERT_REQUIRED", - retries=retry_strategy, # use the customized Retry instance - ) + if proxy: + self.client = urllib3.ProxyManager( + proxy_url=proxy, # e.g., "http://proxy.example.com" + num_pools=num_pools, + headers=self.headers, + ca_certs=certifi.where(), + cert_reqs="CERT_REQUIRED", + retries=retry_strategy, + ) + else: + # https://urllib3.readthedocs.io/en/stable/reference/urllib3.poolmanager.html + # https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool + self.client = urllib3.PoolManager( + num_pools=num_pools, + headers=self.headers, + ca_certs=certifi.where(), + cert_reqs="CERT_REQUIRED", + retries=retry_strategy, + ) self.timeout = urllib3.Timeout(connect=connect_timeout, read=read_timeout) @@ -102,6 +116,12 @@ def _get( headers = self._concat_headers(option.headers) + # Check if path is a full URL or a relative path + if path.startswith("http"): + url = path + else: + url = self.BASE + path + if self.trace: full_url = f"{self.BASE}{path}" if params: @@ -228,7 +248,7 @@ def _paginate_iter( for t in decoded[result_key]: yield deserializer(t) if "next_url" in decoded: - path = decoded["next_url"].replace(self.BASE, "") + path = decoded["next_url"] # Use full next_url params = {} else: return