diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7360b109..39c4eaa0 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -38,10 +38,10 @@ def sample_post_response_data(): """ yield { - "nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", + "nextUri": "https://coordinator/v1/statement/20210817_140827_00000_arvdv/1", "id": "20210817_140827_00000_arvdv", "taskDownloadUris": [], - "infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": "https://coordinator/query.html?20210817_140827_00000_arvdv", "stats": { "scheduled": False, "runningSplits": 0, diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 41d96941..49731d1f 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -71,13 +71,13 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", body=get_statement_callback) # bind get token @@ -136,13 +136,13 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", body=get_statement_callback) # bind get token @@ -201,13 +201,13 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data, samp # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", body=get_statement_callback) # bind get token @@ -281,7 +281,7 @@ def test_hostname_parsing(): https_server_without_port = Connection("https://mytrinoserver.domain") assert https_server_without_port.host == "mytrinoserver.domain" - assert https_server_without_port.port == 8080 + assert https_server_without_port.port == constants.DEFAULT_TLS_PORT assert https_server_without_port.http_scheme == constants.HTTPS http_server_with_port = Connection("http://mytrinoserver.domain:9999") @@ -291,22 +291,22 @@ def test_hostname_parsing(): http_server_without_port = Connection("http://mytrinoserver.domain") assert http_server_without_port.host == "mytrinoserver.domain" - assert http_server_without_port.port == 8080 + assert http_server_without_port.port == constants.DEFAULT_PORT assert http_server_without_port.http_scheme == constants.HTTP http_server_with_path = Connection("http://mytrinoserver.domain/some_path") assert http_server_with_path.host == "mytrinoserver.domain/some_path" - assert http_server_with_path.port == 8080 + assert http_server_with_path.port == constants.DEFAULT_PORT assert http_server_with_path.http_scheme == constants.HTTP only_hostname = Connection("mytrinoserver.domain") assert only_hostname.host == "mytrinoserver.domain" - assert only_hostname.port == 8080 + assert only_hostname.port == constants.DEFAULT_PORT assert only_hostname.http_scheme == constants.HTTP only_hostname_with_path = Connection("mytrinoserver.domain/some_path") assert only_hostname_with_path.host == "mytrinoserver.domain/some_path" - assert only_hostname_with_path.port == 8080 + assert only_hostname_with_path.port == constants.DEFAULT_PORT assert only_hostname_with_path.http_scheme == constants.HTTP diff --git a/trino/dbapi.py b/trino/dbapi.py index f989cced..9d2e2773 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -143,7 +143,7 @@ class Connection: def __init__( self, host: str, - port=constants.DEFAULT_PORT, + port=None, user=None, source=constants.DEFAULT_SOURCE, catalog=constants.DEFAULT_CATALOG, @@ -176,7 +176,6 @@ def __init__( ] self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path - self.port = port if parsed_host.port is None else parsed_host.port self.user = user self.source = source self.catalog = catalog @@ -204,6 +203,16 @@ def __init__( self._http_session = http_session self.http_headers = http_headers self.http_scheme = http_scheme if not parsed_host.scheme else parsed_host.scheme + + # Infer connection port: `hostname` takes precedence over explicit `port` argument + # If none is given, use default based on HTTP protocol + default_port = constants.DEFAULT_TLS_PORT if self.http_scheme == constants.HTTPS else constants.DEFAULT_PORT + self.port = ( + parsed_host.port if parsed_host.port is not None + else port if port is not None + else default_port + ) + self.auth = auth self.extra_credential = extra_credential self.max_attempts = max_attempts