diff --git a/dbtsl/api/graphql/client/asyncio.py b/dbtsl/api/graphql/client/asyncio.py index 20e861b..a426638 100644 --- a/dbtsl/api/graphql/client/asyncio.py +++ b/dbtsl/api/graphql/client/asyncio.py @@ -49,6 +49,7 @@ def __init__( auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ): """Initialize the metadata client. @@ -60,12 +61,13 @@ def __init__( into a full URL. If `None`, the default `https://{server_host}/api/graphql` will be assumed. timeout: TimeoutOptions or total timeout (in seconds) for all GraphQL requests. + client_partner_source: Pass a dbt partner source header for traffic source tracking NOTE: If `timeout` is a `TimeoutOptions`, the `connect_timeout` will not be used, due to limitations of `gql`'s `aiohttp` transport. See: https://github.com/graphql-python/gql/blob/b066e8944b0da0a4bbac6c31f43e5c3c7772cd51/gql/transport/aiohttp.py#L110 """ - super().__init__(server_host, environment_id, auth_token, url_format, timeout) + super().__init__(server_host, environment_id, auth_token, url_format, timeout, client_partner_source) @override def _create_transport(self, url: str, headers: Dict[str, str]) -> AIOHTTPTransport: diff --git a/dbtsl/api/graphql/client/asyncio.pyi b/dbtsl/api/graphql/client/asyncio.pyi index 2ab4e7d..df36b4d 100644 --- a/dbtsl/api/graphql/client/asyncio.pyi +++ b/dbtsl/api/graphql/client/asyncio.pyi @@ -22,6 +22,7 @@ class AsyncGraphQLClient: auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: ... def session(self) -> AbstractAsyncContextManager[AsyncIterator[Self]]: ... @property diff --git a/dbtsl/api/graphql/client/base.py b/dbtsl/api/graphql/client/base.py index 0dbf119..f13899f 100644 --- a/dbtsl/api/graphql/client/base.py +++ b/dbtsl/api/graphql/client/base.py @@ -48,10 +48,8 @@ def _default_backoff(cls) -> ExponentialBackoff: ) @classmethod - def _extra_headers(cls) -> Dict[str, str]: - return { - "user-agent": env.PLATFORM.user_agent, - } + def _extra_headers(cls, client_partner_source: Optional[str] = None) -> Dict[str, str]: + return {"user-agent": env.PLATFORM.user_agent, "x-dbt-partner-source": client_partner_source or "sl-python-sdk"} def __init__( # noqa: D107 self, @@ -60,6 +58,7 @@ def __init__( # noqa: D107 auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ): self.environment_id = environment_id @@ -79,7 +78,7 @@ def __init__( # noqa: D107 headers = { "authorization": f"bearer {auth_token}", - **self._extra_headers(), + **self._extra_headers(client_partner_source), } transport = self._create_transport(url=server_url, headers=headers) self._gql = Client(transport=transport, execute_timeout=self.timeout.execute_timeout) @@ -144,6 +143,7 @@ def __call__( auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> TClient: """Initialize the Semantic Layer client. @@ -153,5 +153,6 @@ def __call__( auth_token: the API auth token url_format: the URL format string to construct the final URL with timeout: `TimeoutOptions` or total timeout + client_partner_source: Pass a dbt partner source header for traffic source tracking """ pass diff --git a/dbtsl/api/graphql/client/sync.py b/dbtsl/api/graphql/client/sync.py index 9964c20..fe36589 100644 --- a/dbtsl/api/graphql/client/sync.py +++ b/dbtsl/api/graphql/client/sync.py @@ -38,6 +38,7 @@ def __init__( auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ): """Initialize the metadata client. @@ -49,11 +50,12 @@ def __init__( into a full URL. If `None`, the default `https://{server_host}/api/graphql` will be assumed. timeout: TimeoutOptions or total timeout (in seconds) for all GraphQL requests. + client_partner_source: Pass a dbt partner source header for traffic source tracking NOTE: If `timeout` is a `TimeoutOptions`, the `tls_close_timeout` will not be used, since `requests` does not support TLS termination timeouts. """ - super().__init__(server_host, environment_id, auth_token, url_format, timeout) + super().__init__(server_host, environment_id, auth_token, url_format, timeout, client_partner_source) @override def _create_transport(self, url: str, headers: Dict[str, str]) -> RequestsHTTPTransport: diff --git a/dbtsl/api/graphql/client/sync.pyi b/dbtsl/api/graphql/client/sync.pyi index 24410e2..ddd568d 100644 --- a/dbtsl/api/graphql/client/sync.pyi +++ b/dbtsl/api/graphql/client/sync.pyi @@ -22,6 +22,7 @@ class SyncGraphQLClient: auth_token: str, url_format: Optional[str] = None, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: ... def session(self) -> AbstractContextManager[Iterator[Self]]: ... @property diff --git a/dbtsl/client/asyncio.py b/dbtsl/client/asyncio.py index 98ffcc9..5ad6870 100644 --- a/dbtsl/client/asyncio.py +++ b/dbtsl/client/asyncio.py @@ -26,6 +26,7 @@ def __init__( auth_token: str, host: str, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: """Initialize the Semantic Layer client. @@ -42,6 +43,7 @@ def __init__( gql_factory=AsyncGraphQLClient, adbc_factory=AsyncADBCClient, timeout=timeout, + client_partner_source=client_partner_source, ) @asynccontextmanager diff --git a/dbtsl/client/asyncio.pyi b/dbtsl/client/asyncio.pyi index a512fcc..647661f 100644 --- a/dbtsl/client/asyncio.pyi +++ b/dbtsl/client/asyncio.pyi @@ -15,6 +15,7 @@ class AsyncSemanticLayerClient: auth_token: str, host: str, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: ... @overload async def compile_sql( diff --git a/dbtsl/client/base.py b/dbtsl/client/base.py index de3c3e5..a891edc 100644 --- a/dbtsl/client/base.py +++ b/dbtsl/client/base.py @@ -40,6 +40,7 @@ def __init__( gql_factory: GraphQLClientFactory[TGQLClient], adbc_factory: ADBCClientFactory[TADBCClient], timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: """Initialize the Semantic Layer client. @@ -50,6 +51,7 @@ def __init__( gql_factory: class of the underlying GQL client adbc_factory: class of the underlying ADBC client timeout: `TimeoutOptions` or total timeout for the underlying GraphQL client. + client_partner_source: Pass a dbt partner source header for traffic source tracking """ self._has_session = False @@ -61,6 +63,7 @@ def __init__( auth_token=auth_token, url_format=env.GRAPHQL_URL_FORMAT, timeout=timeout, + client_partner_source=client_partner_source, ) self._adbc = adbc_factory( server_host=host, diff --git a/dbtsl/client/sync.py b/dbtsl/client/sync.py index 415bfce..79ab588 100644 --- a/dbtsl/client/sync.py +++ b/dbtsl/client/sync.py @@ -26,6 +26,7 @@ def __init__( auth_token: str, host: str, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: """Initialize the Semantic Layer client. @@ -42,6 +43,7 @@ def __init__( gql_factory=SyncGraphQLClient, adbc_factory=SyncADBCClient, timeout=timeout, + client_partner_source=client_partner_source, ) @contextmanager diff --git a/dbtsl/client/sync.pyi b/dbtsl/client/sync.pyi index b0f8ec8..f86507a 100644 --- a/dbtsl/client/sync.pyi +++ b/dbtsl/client/sync.pyi @@ -15,6 +15,7 @@ class SyncSemanticLayerClient: auth_token: str, host: str, timeout: Optional[Union[TimeoutOptions, float, int]] = None, + client_partner_source: Optional[str] = None, ) -> None: ... @overload def compile_sql( diff --git a/pyproject.toml b/pyproject.toml index dcd273d..637d536 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ test = [ "pytest-asyncio>=0.23.7,<0.24.0", "pytest-subtests>=0.12.1,<0.13.0", "pytest-mock>=3.14.0,<4.0.0", + "python-dotenv>=1.0.0,<2.0.0", ] [tool.hatch.build] diff --git a/tests/conftest.py b/tests/conftest.py index 62e5d94..10eb472 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Union, cast import pytest +from dotenv import load_dotenv from gql import Client, gql from gql.utilities.serialize_variable_values import serialize_variable_values @@ -69,6 +70,7 @@ def from_env(cls) -> "Credentials": - SL_TOKEN - SL_ENV_ID """ + load_dotenv() return cls( host=os.environ["SL_HOST"], token=os.environ["SL_TOKEN"], diff --git a/tests/integration/test_sl_client.py b/tests/integration/test_sl_client.py index 871af3d..779e0c7 100644 --- a/tests/integration/test_sl_client.py +++ b/tests/integration/test_sl_client.py @@ -20,6 +20,7 @@ async def async_client(credentials: Credentials) -> AsyncIterator[AsyncSemanticL environment_id=credentials.environment_id, auth_token=credentials.token, host=credentials.host, + client_partner_source="dbt-e2e-tests", ) async with client.session(): yield client @@ -31,6 +32,7 @@ def sync_client(credentials: Credentials) -> Iterator[SyncSemanticLayerClient]: environment_id=credentials.environment_id, auth_token=credentials.token, host=credentials.host, + client_partner_source="dbt-e2e-tests", ) with client.session(): yield client diff --git a/tests/integration/test_timeouts.py b/tests/integration/test_timeouts.py index ea66a24..71a5847 100644 --- a/tests/integration/test_timeouts.py +++ b/tests/integration/test_timeouts.py @@ -44,6 +44,7 @@ def factory(timeout: TimeoutOptions) -> SyncSemanticLayerClient: auth_token=credentials.token, host=credentials.host, timeout=timeout, + client_partner_source="dbt-e2e-tests", ) return factory