@@ -131,7 +131,12 @@ def add_chat_completion_step_to_trace(**kwargs) -> None:
131
131
132
132
133
133
# ----------------------------- Tracing decorator ---------------------------- #
134
- def trace (* step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs ):
134
+ def trace (
135
+ * step_args ,
136
+ inference_pipeline_id : Optional [str ] = None ,
137
+ context_kwarg : Optional [str ] = None ,
138
+ ** step_kwargs ,
139
+ ):
135
140
"""Decorator to trace a function.
136
141
137
142
Examples
@@ -175,7 +180,9 @@ def decorator(func):
175
180
def wrapper (* func_args , ** func_kwargs ):
176
181
if step_kwargs .get ("name" ) is None :
177
182
step_kwargs ["name" ] = func .__name__
178
- with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
183
+ with create_step (
184
+ * step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs
185
+ ) as step :
179
186
output = exception = None
180
187
try :
181
188
output = func (* func_args , ** func_kwargs )
@@ -196,7 +203,10 @@ def wrapper(*func_args, **func_kwargs):
196
203
if context_kwarg in inputs :
197
204
log_context (inputs .get (context_kwarg ))
198
205
else :
199
- logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
206
+ logger .warning (
207
+ "Context kwarg `%s` not found in inputs of the current function." ,
208
+ context_kwarg ,
209
+ )
200
210
201
211
step .log (
202
212
inputs = inputs ,
@@ -215,7 +225,10 @@ def wrapper(*func_args, **func_kwargs):
215
225
216
226
217
227
def trace_async (
218
- * step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs
228
+ * step_args ,
229
+ inference_pipeline_id : Optional [str ] = None ,
230
+ context_kwarg : Optional [str ] = None ,
231
+ ** step_kwargs ,
219
232
):
220
233
"""Decorator to trace a function.
221
234
@@ -260,7 +273,9 @@ def decorator(func):
260
273
async def wrapper (* func_args , ** func_kwargs ):
261
274
if step_kwargs .get ("name" ) is None :
262
275
step_kwargs ["name" ] = func .__name__
263
- with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
276
+ with create_step (
277
+ * step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs
278
+ ) as step :
264
279
output = exception = None
265
280
try :
266
281
output = await func (* func_args , ** func_kwargs )
@@ -281,7 +296,10 @@ async def wrapper(*func_args, **func_kwargs):
281
296
if context_kwarg in inputs :
282
297
log_context (inputs .get (context_kwarg ))
283
298
else :
284
- logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
299
+ logger .warning (
300
+ "Context kwarg `%s` not found in inputs of the current function." ,
301
+ context_kwarg ,
302
+ )
285
303
286
304
step .log (
287
305
inputs = inputs ,
@@ -299,7 +317,9 @@ async def wrapper(*func_args, **func_kwargs):
299
317
return decorator
300
318
301
319
302
- async def _invoke_with_context (coroutine : Awaitable [Any ]) -> Tuple [contextvars .Context , Any ]:
320
+ async def _invoke_with_context (
321
+ coroutine : Awaitable [Any ],
322
+ ) -> Tuple [contextvars .Context , Any ]:
303
323
"""Runs a coroutine and preserves the context variables set within it."""
304
324
result = await coroutine
305
325
context = contextvars .copy_context ()
@@ -356,6 +376,7 @@ def post_process_trace(
356
376
"cost" : processed_steps [0 ].get ("cost" , 0 ),
357
377
"tokens" : processed_steps [0 ].get ("tokens" , 0 ),
358
378
"steps" : processed_steps ,
379
+ ** root_step .metadata ,
359
380
}
360
381
if input_variables :
361
382
trace_data .update (input_variables )
0 commit comments