Skip to content

Commit 8ddb2e4

Browse files
committed
add optional dbt partner source
1 parent f72a118 commit 8ddb2e4

File tree

6 files changed

+16
-8
lines changed

6 files changed

+16
-8
lines changed

dbtsl/api/graphql/client/asyncio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
auth_token: str,
5050
url_format: Optional[str] = None,
5151
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
52+
client_partner_source: Optional[str] = None,
5253
):
5354
"""Initialize the metadata client.
5455
@@ -65,7 +66,7 @@ def __init__(
6566
limitations of `gql`'s `aiohttp` transport.
6667
See: https://github.com/graphql-python/gql/blob/b066e8944b0da0a4bbac6c31f43e5c3c7772cd51/gql/transport/aiohttp.py#L110
6768
"""
68-
super().__init__(server_host, environment_id, auth_token, url_format, timeout)
69+
super().__init__(server_host, environment_id, auth_token, url_format, timeout, client_partner_source)
6970

7071
@override
7172
def _create_transport(self, url: str, headers: Dict[str, str]) -> AIOHTTPTransport:

dbtsl/api/graphql/client/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def _default_backoff(cls) -> ExponentialBackoff:
4848
)
4949

5050
@classmethod
51-
def _extra_headers(cls) -> Dict[str, str]:
52-
return {
53-
"user-agent": env.PLATFORM.user_agent,
54-
}
51+
def _extra_headers(cls, client_partner_source: Optional[str] = None) -> Dict[str, str]:
52+
headers = {"user-agent": env.PLATFORM.user_agent}
53+
if client_partner_source is not None:
54+
headers["X-Dbt-Partner-Source"] = client_partner_source
55+
return headers
5556

5657
def __init__( # noqa: D107
5758
self,
@@ -60,6 +61,7 @@ def __init__( # noqa: D107
6061
auth_token: str,
6162
url_format: Optional[str] = None,
6263
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
64+
client_partner_source: Optional[str] = None,
6365
):
6466
self.environment_id = environment_id
6567

@@ -79,7 +81,7 @@ def __init__( # noqa: D107
7981

8082
headers = {
8183
"authorization": f"bearer {auth_token}",
82-
**self._extra_headers(),
84+
**self._extra_headers(client_partner_source),
8385
}
8486
transport = self._create_transport(url=server_url, headers=headers)
8587
self._gql = Client(transport=transport, execute_timeout=self.timeout.execute_timeout)

dbtsl/api/graphql/client/sync.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
auth_token: str,
3939
url_format: Optional[str] = None,
4040
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
41+
client_partner_source: Optional[str] = None,
4142
):
4243
"""Initialize the metadata client.
4344
@@ -49,11 +50,10 @@ def __init__(
4950
into a full URL. If `None`, the default `https://{server_host}/api/graphql`
5051
will be assumed.
5152
timeout: TimeoutOptions or total timeout (in seconds) for all GraphQL requests.
52-
5353
NOTE: If `timeout` is a `TimeoutOptions`, the `tls_close_timeout` will not be used, since
5454
`requests` does not support TLS termination timeouts.
5555
"""
56-
super().__init__(server_host, environment_id, auth_token, url_format, timeout)
56+
super().__init__(server_host, environment_id, auth_token, url_format, timeout, client_partner_source)
5757

5858
@override
5959
def _create_transport(self, url: str, headers: Dict[str, str]) -> RequestsHTTPTransport:

dbtsl/client/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
gql_factory: GraphQLClientFactory[TGQLClient],
4141
adbc_factory: ADBCClientFactory[TADBCClient],
4242
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
43+
requests_headers: Optional[dict[str, str]] = None,
4344
) -> None:
4445
"""Initialize the Semantic Layer client.
4546
@@ -50,6 +51,7 @@ def __init__(
5051
gql_factory: class of the underlying GQL client
5152
adbc_factory: class of the underlying ADBC client
5253
timeout: `TimeoutOptions` or total timeout for the underlying GraphQL client.
54+
requests_headers: additional headers to pass to the requests, optional
5355
"""
5456
self._has_session = False
5557

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ test = [
3535
"pytest-asyncio>=0.23.7,<0.24.0",
3636
"pytest-subtests>=0.12.1,<0.13.0",
3737
"pytest-mock>=3.14.0,<4.0.0",
38+
"python-dotenv>=1.0.0,<2.0.0",
3839
]
3940

4041
[tool.hatch.build]

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from gql import Client, gql
7+
from dotenv import load_dotenv
78
from gql.utilities.serialize_variable_values import serialize_variable_values
89

910

@@ -69,6 +70,7 @@ def from_env(cls) -> "Credentials":
6970
- SL_TOKEN
7071
- SL_ENV_ID
7172
"""
73+
load_dotenv()
7274
return cls(
7375
host=os.environ["SL_HOST"],
7476
token=os.environ["SL_TOKEN"],

0 commit comments

Comments
 (0)