Skip to content

Commit b793524

Browse files
1vndeliahu
authored and committed
Timeout fix and remove logging (#317)
1 parent 70004b1 commit b793524

File tree

2 files changed

+33
-47
lines changed

2 files changed

+33
-47
lines changed

pkg/workloads/cortex/onnx_serve/api.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,8 @@
6666
}
6767

6868

69-
def prediction_failed(sample, reason=None):
70-
message = "prediction failed for sample: {}".format(util.pp_str_flat(sample))
71-
if reason:
72-
message += " ({})".format(reason)
69+
def prediction_failed(reason):
70+
message = "prediction failed: " + reason
7371

7472
logger.error(message)
7573
return message, status.HTTP_406_NOT_ACCEPTABLE
@@ -113,7 +111,7 @@ def convert_to_onnx_input(sample, input_metadata_list):
113111
try:
114112
input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
115113
except CortexException as e:
116-
e.wrap("key {}".format(input_metadata.name))
114+
e.wrap('key "{}"'.format(input_metadata.name))
117115
raise
118116
else:
119117
for input_metadata in input_metadata_list:
@@ -130,7 +128,7 @@ def convert_to_onnx_input(sample, input_metadata_list):
130128
try:
131129
input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
132130
except CortexException as e:
133-
e.wrap("key {}".format(input_metadata.name))
131+
e.wrap('key "{}"'.format(input_metadata.name))
134132
raise
135133
return input_dict
136134

@@ -151,24 +149,18 @@ def predict(app_name, api_name):
151149
response = {}
152150

153151
if not util.is_dict(payload) or "samples" not in payload:
154-
util.log_pretty_flat(payload, logging_func=logger.error)
155-
return prediction_failed(payload, "top level `samples` key not found in request")
152+
return prediction_failed('top level "samples" key not found in request')
156153

157154
predictions = []
158155
samples = payload["samples"]
159156
if not util.is_list(samples):
160-
util.log_pretty_flat(samples, logging_func=logger.error)
161-
return prediction_failed(
162-
payload, "expected the value of key `samples` to be a list of json objects"
163-
)
157+
return prediction_failed('expected the value of key "samples" to be a list of json objects')
164158

165159
for i, sample in enumerate(payload["samples"]):
166160
try:
167-
logger.info("sample: " + util.pp_str_flat(sample))
168161
prepared_sample = sample
169162
if request_handler is not None and util.has_function(request_handler, "pre_inference"):
170163
prepared_sample = request_handler.pre_inference(sample, input_metadata)
171-
logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))
172164

173165
inference_input = convert_to_onnx_input(prepared_sample, input_metadata)
174166
model_outputs = sess.run([], inference_input)
@@ -179,10 +171,8 @@ def predict(app_name, api_name):
179171
else:
180172
result.append(model_output)
181173

182-
logger.info("inference: " + util.pp_str_flat(result))
183174
if request_handler is not None and util.has_function(request_handler, "post_inference"):
184175
result = request_handler.post_inference(result, output_metadata)
185-
logger.info("post_inference: " + util.pp_str_flat(result))
186176

187177
prediction = {"prediction": result}
188178
except CortexException as e:
@@ -191,12 +181,12 @@ def predict(app_name, api_name):
191181
logger.exception(
192182
"An error occurred, see `cx logs -v api {}` for more details.".format(api["name"])
193183
)
194-
return prediction_failed(sample, str(e))
184+
return prediction_failed(str(e))
195185
except Exception as e:
196186
logger.exception(
197187
"An error occurred, see `cx logs -v api {}` for more details.".format(api["name"])
198188
)
199-
return prediction_failed(sample, str(e))
189+
return prediction_failed(str(e))
200190

201191
predictions.append(prediction)
202192

pkg/workloads/cortex/tf_api/api.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,13 @@ def create_prediction_request(transformed_sample):
136136
shape = [1]
137137
if util.is_list(value):
138138
shape = [len(value)]
139-
tensor_proto = tf.make_tensor_proto([value], dtype=data_type, shape=shape)
140-
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
141-
139+
try:
140+
tensor_proto = tf.make_tensor_proto([value], dtype=data_type, shape=shape)
141+
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
142+
except Exception as e:
143+
raise UserException(
144+
'key "{}"'.format(column_name), "expected shape {}".format(shape)
145+
) from e
142146
return prediction_request
143147

144148

@@ -160,8 +164,15 @@ def create_raw_prediction_request(sample):
160164
shape = [1]
161165
value = [value]
162166
sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"]
163-
tensor_proto = tf.make_tensor_proto(value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape)
164-
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
167+
try:
168+
tensor_proto = tf.make_tensor_proto(
169+
value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape
170+
)
171+
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
172+
except Exception as e:
173+
raise UserException(
174+
'key "{}"'.format(column_name), "expected shape {}".format(shape)
175+
) from e
165176

166177
return prediction_request
167178

@@ -248,7 +259,7 @@ def create_get_model_metadata_request():
248259

249260
def run_get_model_metadata():
250261
request = create_get_model_metadata_request()
251-
resp = local_cache["stub"].GetModelMetadata(request, timeout=10.0)
262+
resp = local_cache["stub"].GetModelMetadata(request, timeout=30.0)
252263
sigAny = resp.metadata["signature_def"]
253264
signature_def_map = get_model_metadata_pb2.SignatureDefMap()
254265
sigAny.Unpack(signature_def_map)
@@ -272,14 +283,11 @@ def run_predict(sample):
272283
ctx = local_cache["ctx"]
273284
request_handler = local_cache.get("request_handler")
274285

275-
logger.info("sample: " + util.pp_str_flat(sample))
276-
277286
prepared_sample = sample
278287
if request_handler is not None and util.has_function(request_handler, "pre_inference"):
279288
prepared_sample = request_handler.pre_inference(
280289
sample, local_cache["metadata"]["signatureDef"]
281290
)
282-
logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))
283291

284292
validate_sample(prepared_sample)
285293

@@ -291,24 +299,18 @@ def run_predict(sample):
291299
)
292300

293301
transformed_sample = transform_sample(prepared_sample)
294-
logger.info("transformed_sample: " + util.pp_str_flat(transformed_sample))
295-
296302
prediction_request = create_prediction_request(transformed_sample)
297-
response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0)
303+
response_proto = local_cache["stub"].Predict(prediction_request, timeout=100.0)
298304
result = parse_response_proto(response_proto)
299305

300306
result["transformed_sample"] = transformed_sample
301-
logger.info("inference: " + util.pp_str_flat(result))
302307
else:
303308
prediction_request = create_raw_prediction_request(prepared_sample)
304-
response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0)
309+
response_proto = local_cache["stub"].Predict(prediction_request, timeout=100.0)
305310
result = parse_response_proto_raw(response_proto)
306311

307-
logger.info("inference: " + util.pp_str_flat(result))
308-
309312
if request_handler is not None and util.has_function(request_handler, "post_inference"):
310313
result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"])
311-
logger.info("post_inference: " + util.pp_str_flat(result))
312314

313315
return result
314316

@@ -335,10 +337,8 @@ def validate_sample(sample):
335337
raise UserException('missing key "{}"'.format(input_name))
336338

337339

338-
def prediction_failed(sample, reason=None):
339-
message = "prediction failed for sample: {}".format(util.pp_str_flat(sample))
340-
if reason:
341-
message += " ({})".format(reason)
340+
def prediction_failed(reason):
341+
message = "prediction failed: " + reason
342342

343343
logger.error(message)
344344
return message, status.HTTP_406_NOT_ACCEPTABLE
@@ -363,16 +363,12 @@ def predict(deployment_name, api_name):
363363
response = {}
364364

365365
if not util.is_dict(payload) or "samples" not in payload:
366-
util.log_pretty_flat(payload, logging_func=logger.error)
367-
return prediction_failed(payload, "top level `samples` key not found in request")
366+
return prediction_failed('top level "samples" key not found in request')
368367

369368
predictions = []
370369
samples = payload["samples"]
371370
if not util.is_list(samples):
372-
util.log_pretty_flat(samples, logging_func=logger.error)
373-
return prediction_failed(
374-
payload, "expected the value of key `samples` to be a list of json objects"
375-
)
371+
return prediction_failed('expected the value of key "samples" to be a list of json objects')
376372

377373
for i, sample in enumerate(payload["samples"]):
378374
try:
@@ -385,14 +381,14 @@ def predict(deployment_name, api_name):
385381
api["name"]
386382
)
387383
)
388-
return prediction_failed(sample, str(e))
384+
return prediction_failed(str(e))
389385
except Exception as e:
390386
logger.exception(
391387
"An error occurred, see `cortex logs -v api {}` for more details.".format(
392388
api["name"]
393389
)
394390
)
395-
return prediction_failed(sample, str(e))
391+
return prediction_failed(str(e))
396392

397393
predictions.append(result)
398394

0 commit comments

Comments (0)