1
+ import contextlib
1
2
import json
2
3
import shutil
3
- from typing import Any , Callable , Dict , Optional , Union
4
+ from contextvars import ContextVar
5
+ from typing import Any , Callable , Dict , Generator , Optional , Union
4
6
5
7
import requests
6
8
7
9
import audiostack
8
10
from audiostack .helpers .request_types import RequestTypes
9
11
12
+ _current_trace_id : ContextVar [Optional [str ]] = ContextVar (
13
+ "current_trace_id" , default = None
14
+ )
15
+
10
16
11
17
def remove_empty (data : Any ) -> Any :
12
18
if not (isinstance (data , dict ) or isinstance (data , list )):
@@ -32,14 +38,19 @@ def __init__(self, family: str) -> None:
32
38
self .family = family
33
39
34
40
@staticmethod
35
- def make_header () -> dict :
36
- header = {
41
+ def make_header (headers : Optional [ dict ] = None ) -> dict :
42
+ new_headers = {
37
43
"x-api-key" : audiostack .api_key ,
38
44
"x-python-sdk-version" : audiostack .sdk_version ,
39
45
}
46
+ current_trace_id = _current_trace_id .get ()
47
+ if current_trace_id is not None :
48
+ new_headers ["x-customer-trace-id" ] = current_trace_id
40
49
if audiostack .assume_org_id :
41
- header ["x-assume-org" ] = audiostack .assume_org_id
42
- return header
50
+ new_headers ["x-assume-org" ] = audiostack .assume_org_id
51
+ if headers :
52
+ new_headers .update (headers )
53
+ return new_headers
43
54
44
55
def resolve_response (self , r : Any ) -> dict :
45
56
if self .DEBUG_PRINT :
@@ -82,6 +93,7 @@ def send_request(
82
93
path_parameters : Optional [Union [dict , str ]] = None ,
83
94
query_parameters : Optional [Union [dict , str ]] = None ,
84
95
overwrite_base_url : Optional [str ] = None ,
96
+ headers : Optional [dict ] = None ,
85
97
) -> Any :
86
98
if overwrite_base_url :
87
99
url = overwrite_base_url
@@ -111,15 +123,15 @@ def send_request(
111
123
}
112
124
113
125
return self .resolve_response (
114
- FUNC_MAP [rtype ](url = url , json = json , headers = self .make_header ())
126
+ FUNC_MAP [rtype ](url = url , json = json , headers = self .make_header (headers ))
115
127
)
116
128
elif rtype == RequestTypes .GET :
117
129
if path_parameters :
118
130
url = f"{ url } /{ path_parameters } "
119
131
120
132
return self .resolve_response (
121
133
requests .get (
122
- url = url , params = query_parameters , headers = self .make_header ()
134
+ url = url , params = query_parameters , headers = self .make_header (headers )
123
135
)
124
136
)
125
137
elif rtype == RequestTypes .DELETE :
@@ -128,7 +140,7 @@ def send_request(
128
140
129
141
return self .resolve_response (
130
142
requests .delete (
131
- url = url , params = query_parameters , headers = self .make_header ()
143
+ url = url , params = query_parameters , headers = self .make_header (headers )
132
144
)
133
145
)
134
146
@@ -142,3 +154,12 @@ def download_url(cls, url: str, name: str, destination: str) -> None:
142
154
local_filename = f"{ destination } /{ name } "
143
155
with open (local_filename , "wb" ) as f :
144
156
shutil .copyfileobj (r .raw , f )
157
+
158
+
159
+ @contextlib .contextmanager
160
+ def use_trace (trace_id : str ) -> Generator [None , None , None ]:
161
+ token = _current_trace_id .set (trace_id )
162
+ try :
163
+ yield
164
+ finally :
165
+ _current_trace_id .reset (token )
0 commit comments