diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 49731d1f..080a3904 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -14,6 +14,7 @@ from unittest.mock import patch import httpretty +import pytest from httpretty import httprettified from requests import Session @@ -314,3 +315,26 @@ def test_description_is_none_when_cursor_is_not_executed(): connection = Connection("sample_trino_cluster:443") with connection.cursor() as cursor: assert hasattr(cursor, 'description') + + +@pytest.mark.parametrize( + "host, port, http_scheme_input_argument, http_scheme_set", + [ + # Infer from hostname + ("https://mytrinoserver.domain:9999", None, None, constants.HTTPS), + ("http://mytrinoserver.domain:9999", None, None, constants.HTTP), + # Infer from port + ("mytrinoserver.domain", constants.DEFAULT_TLS_PORT, None, constants.HTTPS), + ("mytrinoserver.domain", constants.DEFAULT_PORT, None, constants.HTTP), + # http_scheme parameter has higher precedence than port parameter + ("mytrinoserver.domain", constants.DEFAULT_TLS_PORT, constants.HTTP, constants.HTTP), + ("mytrinoserver.domain", constants.DEFAULT_PORT, constants.HTTPS, constants.HTTPS), + # Set explicitly by http_scheme parameter + ("mytrinoserver.domain", None, constants.HTTPS, constants.HTTPS), + # Default + ("mytrinoserver.domain", None, None, constants.HTTP), + ], +) +def test_setting_http_scheme(host, port, http_scheme_input_argument, http_scheme_set): + connection = Connection(host, port, http_scheme=http_scheme_input_argument) + assert connection.http_scheme == http_scheme_set diff --git a/trino/dbapi.py b/trino/dbapi.py index 9d2e2773..fb73cc38 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -150,7 +150,7 @@ def __init__( schema=constants.DEFAULT_SCHEMA, session_properties=None, http_headers=None, - http_scheme=constants.HTTP, + http_scheme=None, auth=constants.DEFAULT_AUTH, extra_credential=None, max_attempts=constants.DEFAULT_MAX_ATTEMPTS, @@ -202,7 +202,18 @@ def __init__( else: self._http_session = http_session self.http_headers = http_headers - self.http_scheme = http_scheme if not parsed_host.scheme else parsed_host.scheme + + # Set http_scheme + if parsed_host.scheme: + self.http_scheme = parsed_host.scheme + elif http_scheme: + self.http_scheme = http_scheme + elif port == constants.DEFAULT_TLS_PORT: + self.http_scheme = constants.HTTPS + elif port == constants.DEFAULT_PORT: + self.http_scheme = constants.HTTP + else: + self.http_scheme = constants.HTTP # Infer connection port: `hostname` takes precedence over explicit `port` argument # If none is given, use default based on HTTP protocol