@@ -136,6 +136,9 @@ def _get_modified_create_chat_completion(self) -> callable:
         def modified_create_chat_completion(*args, **kwargs) -> str:
             stream = kwargs.get("stream", False)
 
+            # Pop the reserved Openlayer kwargs
+            inference_id = kwargs.pop("inference_id", None)
+
             if not stream:
                 start_time = time.time()
                 response = self.create_chat_completion(*args, **kwargs)
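This first hunk reserves `inference_id` as an Openlayer-specific keyword and pops it out of `kwargs` before the wrapped `create_chat_completion` is invoked, so the reserved argument never reaches the underlying client (which would otherwise reject the unexpected keyword with a `TypeError`). A minimal standalone sketch of that pattern, with the `make_wrapper` and `real_create` names invented for illustration, not taken from the patch:

def make_wrapper(real_create):
    def wrapped(*args, **kwargs):
        # Pop the reserved kwarg first; if it leaked through, real_create
        # would raise a TypeError for an unexpected keyword argument.
        inference_id = kwargs.pop("inference_id", None)
        result = real_create(*args, **kwargs)
        print(f"would trace this call with id={inference_id!r}")
        return result
    return wrapped

create = make_wrapper(lambda **kw: {"echo": kw})
create(messages=[{"role": "user", "content": "Hi"}], inference_id="row-7")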
@@ -169,21 +172,26 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
                         num_input_tokens=response.usage.prompt_tokens,
                         num_output_tokens=response.usage.completion_tokens,
                     )
-
-                    self._add_to_trace(
-                        end_time=end_time,
-                        inputs={
+                    trace_args = {
+                        "end_time": end_time,
+                        "inputs": {
                             "prompt": kwargs["messages"],
                         },
-                        output=output_data,
-                        latency=(end_time - start_time) * 1000,
-                        tokens=response.usage.total_tokens,
-                        cost=cost,
-                        prompt_tokens=response.usage.prompt_tokens,
-                        completion_tokens=response.usage.completion_tokens,
-                        model=response.model,
-                        model_parameters=kwargs.get("model_parameters"),
-                        raw_output=response.model_dump(),
+                        "output": output_data,
+                        "latency": (end_time - start_time) * 1000,
+                        "tokens": response.usage.total_tokens,
+                        "cost": cost,
+                        "prompt_tokens": response.usage.prompt_tokens,
+                        "completion_tokens": response.usage.completion_tokens,
+                        "model": response.model,
+                        "model_parameters": kwargs.get("model_parameters"),
+                        "raw_output": response.model_dump(),
+                    }
+                    if inference_id:
+                        trace_args["id"] = str(inference_id)
+
+                    self._add_to_trace(
+                        **trace_args,
                     )
                 # pylint: disable=broad-except
                 except Exception as e:
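In the non-streaming path, the keyword arguments for `self._add_to_trace` are now gathered into a `trace_args` dict first, so the optional `id` key is present only when the caller actually supplied an `inference_id`; when it is absent, the tracing layer is free to apply its own default. A hedged illustration of that conditional-kwargs shape, with a hypothetical `record` function standing in for the tracer:

def record(**fields):
    # Stand-in tracer: generate an id only when the caller omitted one.
    fields.setdefault("id", "auto-generated")
    return fields

trace_args = {"output": "hello", "tokens": 42}
inference_id = None  # nothing supplied by the caller
if inference_id:
    trace_args["id"] = str(inference_id)
print(record(**trace_args))  # "id" falls back to "auto-generated"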
@@ -267,28 +275,33 @@ def stream_chunks():
                             else 0
                         ),
                     )
-
-                    self._add_to_trace(
-                        end_time=end_time,
-                        inputs={
+                    trace_args = {
+                        "end_time": end_time,
+                        "inputs": {
                             "prompt": kwargs["messages"],
                         },
-                        output=output_data,
-                        latency=latency,
-                        tokens=num_of_completion_tokens,
-                        cost=completion_cost,
-                        prompt_tokens=None,
-                        completion_tokens=num_of_completion_tokens,
-                        model=kwargs.get("model"),
-                        model_parameters=kwargs.get("model_parameters"),
-                        raw_output=raw_outputs,
-                        metadata={
+                        "output": output_data,
+                        "latency": latency,
+                        "tokens": num_of_completion_tokens,
+                        "cost": completion_cost,
+                        "prompt_tokens": None,
+                        "completion_tokens": num_of_completion_tokens,
+                        "model": kwargs.get("model"),
+                        "model_parameters": kwargs.get("model_parameters"),
+                        "raw_output": raw_outputs,
+                        "metadata": {
                             "timeToFirstToken": (
                                 (first_token_time - start_time) * 1000
                                 if first_token_time
                                 else None
                             )
                         },
+                    }
+                    if inference_id:
+                        trace_args["id"] = str(inference_id)
+
+                    self._add_to_trace(
+                        **trace_args,
                     )
                 # pylint: disable=broad-except
                 except Exception as e:
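Taken together, the hunks let a caller attach their own ID to a traced request in both the streaming and non-streaming paths. Assuming `client` is an OpenAI-style client whose chat-completion method has been swapped for `modified_create_chat_completion` (the exact patch target is outside these hunks, so the client setup here is a placeholder), usage would look like:

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi!"}],
    inference_id="row-123",  # reserved kwarg: popped by the wrapper, attached to the trace as "id"
)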