diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 909823ee..85ac25f7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -130,8 +130,8 @@ def assert_headers(headers): assert headers[constants.HEADER_CATALOG] == catalog assert headers[constants.HEADER_SCHEMA] == schema assert headers[constants.HEADER_SOURCE] == source - assert headers[constants.HEADER_USER] == user - assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user + assert headers[constants.HEADER_ORIGINAL_USER] == user + assert headers[constants.HEADER_USER] == authorization_user assert headers[constants.HEADER_SESSION] == "" assert headers[constants.HEADER_TRANSACTION] is None assert headers[constants.HEADER_TIMEZONE] == timezone diff --git a/trino/auth.py b/trino/auth.py index 306539f0..783d0f5e 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -34,6 +34,7 @@ import trino.logging from trino import exceptions +from trino.constants import HEADER_ORIGINAL_USER from trino.constants import HEADER_USER from trino.constants import MAX_NT_PASSWORD_SIZE @@ -552,7 +553,7 @@ def _determine_host(url: Optional[str]) -> Any: @staticmethod def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]: - return headers.get(HEADER_USER) + return headers.get(HEADER_ORIGINAL_USER, headers.get(HEADER_USER)) @staticmethod def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]: diff --git a/trino/client.py b/trino/client.py index 7cc1f0f2..cf8ec9d1 100644 --- a/trino/client.py +++ b/trino/client.py @@ -511,8 +511,11 @@ def http_headers(self) -> CaseInsensitiveDict[str]: headers[constants.HEADER_CATALOG] = self._client_session.catalog headers[constants.HEADER_SCHEMA] = self._client_session.schema headers[constants.HEADER_SOURCE] = self._client_session.source - headers[constants.HEADER_USER] = self._client_session.user - headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user + if self._client_session.authorization_user is not None: + headers[constants.HEADER_ORIGINAL_USER] = self._client_session.user + headers[constants.HEADER_USER] = self._client_session.authorization_user + else: + headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone if self._client_session.encoding is None: pass diff --git a/trino/constants.py b/trino/constants.py index 46b6e9ec..b136aaaf 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -33,6 +33,7 @@ HEADER_SCHEMA = "X-Trino-Schema" HEADER_SOURCE = "X-Trino-Source" HEADER_USER = "X-Trino-User" +HEADER_ORIGINAL_USER = "X-Trino-Original-User" HEADER_CLIENT_INFO = "X-Trino-Client-Info" HEADER_CLIENT_TAGS = "X-Trino-Client-Tags" HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential" @@ -61,7 +62,6 @@ CLIENT_CAPABILITY_SESSION_AUTHORIZATION = "SESSION_AUTHORIZATION" CLIENT_CAPABILITIES = ','.join([CLIENT_CAPABILITY_PARAMETRIC_DATETIME, CLIENT_CAPABILITY_SESSION_AUTHORIZATION]) -HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User" HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User" HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"