@@ -136,9 +136,13 @@ def create_prediction_request(transformed_sample):
136
136
shape = [1 ]
137
137
if util .is_list (value ):
138
138
shape = [len (value )]
139
- tensor_proto = tf .make_tensor_proto ([value ], dtype = data_type , shape = shape )
140
- prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
141
-
139
+ try :
140
+ tensor_proto = tf .make_tensor_proto ([value ], dtype = data_type , shape = shape )
141
+ prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
142
+ except Exception as e :
143
+ raise UserException (
144
+ 'key "{}"' .format (column_name ), "expected shape {}" .format (shape )
145
+ ) from e
142
146
return prediction_request
143
147
144
148
@@ -160,8 +164,15 @@ def create_raw_prediction_request(sample):
160
164
shape = [1 ]
161
165
value = [value ]
162
166
sig_type = signature_def [signature_key ]["inputs" ][column_name ]["dtype" ]
163
- tensor_proto = tf .make_tensor_proto (value , dtype = DTYPE_TO_TF_TYPE [sig_type ], shape = shape )
164
- prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
167
+ try :
168
+ tensor_proto = tf .make_tensor_proto (
169
+ value , dtype = DTYPE_TO_TF_TYPE [sig_type ], shape = shape
170
+ )
171
+ prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
172
+ except Exception as e :
173
+ raise UserException (
174
+ 'key "{}"' .format (column_name ), "expected shape {}" .format (shape )
175
+ ) from e
165
176
166
177
return prediction_request
167
178
@@ -248,7 +259,7 @@ def create_get_model_metadata_request():
248
259
249
260
def run_get_model_metadata ():
250
261
request = create_get_model_metadata_request ()
251
- resp = local_cache ["stub" ].GetModelMetadata (request , timeout = 10 .0 )
262
+ resp = local_cache ["stub" ].GetModelMetadata (request , timeout = 30 .0 )
252
263
sigAny = resp .metadata ["signature_def" ]
253
264
signature_def_map = get_model_metadata_pb2 .SignatureDefMap ()
254
265
sigAny .Unpack (signature_def_map )
@@ -272,14 +283,11 @@ def run_predict(sample):
272
283
ctx = local_cache ["ctx" ]
273
284
request_handler = local_cache .get ("request_handler" )
274
285
275
- logger .info ("sample: " + util .pp_str_flat (sample ))
276
-
277
286
prepared_sample = sample
278
287
if request_handler is not None and util .has_function (request_handler , "pre_inference" ):
279
288
prepared_sample = request_handler .pre_inference (
280
289
sample , local_cache ["metadata" ]["signatureDef" ]
281
290
)
282
- logger .info ("pre_inference: " + util .pp_str_flat (prepared_sample ))
283
291
284
292
validate_sample (prepared_sample )
285
293
@@ -291,24 +299,18 @@ def run_predict(sample):
291
299
)
292
300
293
301
transformed_sample = transform_sample (prepared_sample )
294
- logger .info ("transformed_sample: " + util .pp_str_flat (transformed_sample ))
295
-
296
302
prediction_request = create_prediction_request (transformed_sample )
297
- response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 10 .0 )
303
+ response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 100 .0 )
298
304
result = parse_response_proto (response_proto )
299
305
300
306
result ["transformed_sample" ] = transformed_sample
301
- logger .info ("inference: " + util .pp_str_flat (result ))
302
307
else :
303
308
prediction_request = create_raw_prediction_request (prepared_sample )
304
- response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 10 .0 )
309
+ response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 100 .0 )
305
310
result = parse_response_proto_raw (response_proto )
306
311
307
- logger .info ("inference: " + util .pp_str_flat (result ))
308
-
309
312
if request_handler is not None and util .has_function (request_handler , "post_inference" ):
310
313
result = request_handler .post_inference (result , local_cache ["metadata" ]["signatureDef" ])
311
- logger .info ("post_inference: " + util .pp_str_flat (result ))
312
314
313
315
return result
314
316
@@ -335,10 +337,8 @@ def validate_sample(sample):
335
337
raise UserException ('missing key "{}"' .format (input_name ))
336
338
337
339
338
- def prediction_failed (sample , reason = None ):
339
- message = "prediction failed for sample: {}" .format (util .pp_str_flat (sample ))
340
- if reason :
341
- message += " ({})" .format (reason )
340
+ def prediction_failed (reason ):
341
+ message = "prediction failed: " + reason
342
342
343
343
logger .error (message )
344
344
return message , status .HTTP_406_NOT_ACCEPTABLE
@@ -363,16 +363,12 @@ def predict(deployment_name, api_name):
363
363
response = {}
364
364
365
365
if not util .is_dict (payload ) or "samples" not in payload :
366
- util .log_pretty_flat (payload , logging_func = logger .error )
367
- return prediction_failed (payload , "top level `samples` key not found in request" )
366
+ return prediction_failed ('top level "samples" key not found in request' )
368
367
369
368
predictions = []
370
369
samples = payload ["samples" ]
371
370
if not util .is_list (samples ):
372
- util .log_pretty_flat (samples , logging_func = logger .error )
373
- return prediction_failed (
374
- payload , "expected the value of key `samples` to be a list of json objects"
375
- )
371
+ return prediction_failed ('expected the value of key "samples" to be a list of json objects' )
376
372
377
373
for i , sample in enumerate (payload ["samples" ]):
378
374
try :
@@ -385,14 +381,14 @@ def predict(deployment_name, api_name):
385
381
api ["name" ]
386
382
)
387
383
)
388
- return prediction_failed (sample , str (e ))
384
+ return prediction_failed (str (e ))
389
385
except Exception as e :
390
386
logger .exception (
391
387
"An error occurred, see `cortex logs -v api {}` for more details." .format (
392
388
api ["name" ]
393
389
)
394
390
)
395
- return prediction_failed (sample , str (e ))
391
+ return prediction_failed (str (e ))
396
392
397
393
predictions .append (result )
398
394
0 commit comments