@@ -155,7 +155,14 @@ def wrapper(*func_args, **func_kwargs):
155
155
if step_kwargs .get ("name" ) is None :
156
156
step_kwargs ["name" ] = func .__name__
157
157
with create_step (* step_args , ** step_kwargs ) as step :
158
- output = func (* func_args , ** func_kwargs )
158
+ output = None
159
+ exception = None
160
+ try :
161
+ output = func (* func_args , ** func_kwargs )
162
+ # pylint: disable=broad-except
163
+ except Exception as exc :
164
+ step .log (metadata = {"Exceptions" : str (exc )})
165
+ exception = exc
159
166
end_time = time .time ()
160
167
latency = (end_time - step .start_time ) * 1000 # in ms
161
168
@@ -171,6 +178,9 @@ def wrapper(*func_args, **func_kwargs):
171
178
end_time = end_time ,
172
179
latency = latency ,
173
180
)
181
+
182
+ if exception is not None :
183
+ raise exception
174
184
return output
175
185
176
186
return wrapper
@@ -189,12 +199,14 @@ def process_trace_for_upload(
189
199
root_step = trace_obj .steps [0 ]
190
200
191
201
input_variables = root_step .inputs
192
- input_variable_names = list (input_variables .keys ())
202
+ if input_variables :
203
+ input_variable_names = list (input_variables .keys ())
204
+ else :
205
+ input_variable_names = []
193
206
194
207
processed_steps = bubble_up_costs_and_tokens (trace_obj .to_dict ())
195
208
196
209
trace_data = {
197
- ** input_variables ,
198
210
"inferenceTimestamp" : root_step .start_time ,
199
211
"inferenceId" : str (root_step .id ),
200
212
"output" : root_step .output ,
@@ -204,6 +216,8 @@ def process_trace_for_upload(
204
216
"tokens" : processed_steps [0 ].get ("tokens" , 0 ),
205
217
"steps" : processed_steps ,
206
218
}
219
+ if input_variables :
220
+ trace_data .update (input_variables )
207
221
208
222
return trace_data , input_variable_names
209
223
0 commit comments