Skip to content

How to use code with onnx? #211

Open
@za13

Description

@za13

I'm using the code from the notebook in the "examples" directory.

I'm also trying to use keras2onnx with it. I tried:

import keras2onnx
import onnxruntime

# Convert the in-memory Keras model to an ONNX ModelProto.
# NOTE(review): the reported failure ("Unknown input node
# '^block1b_drop/cond/switch_t'") is raised inside this call, on a node under
# a dropout layer's `cond` scope. A commonly suggested workaround is to put
# Keras in inference mode (keras.backend.set_learning_phase(0)) *before* the
# model is built/loaded, so training-phase conditionals never enter the
# graph -- confirm against the keras2onnx docs for your TF/Keras versions.
onnx_model = keras2onnx.convert_keras(model, model.name)

# Run the converted model with ONNX Runtime, loading it from the serialized
# bytes rather than from a file on disk.
content = onnx_model.SerializeToString()
sess = onnxruntime.InferenceSession(content)

# Normalise to one array per graph input, matched positionally to the inputs
# in the order ONNX Runtime reports them (an IndexError signals too few arrays).
x = x if isinstance(x, list) else [x]
# `inp` rather than `input` so the builtin name is not shadowed.
feed = {inp.name: x[n] for n, inp in enumerate(sess.get_inputs())}
pred_onnx = sess.run(None, feed)  # None -> return every model output

But I got the following error:

InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    426         results = c_api.TF_GraphImportGraphDefWithResults(
--> 427             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    428         results = c_api_util.ScopedTFImportGraphDefResults(results)

InvalidArgumentError: Node 'block1b_drop/cond/mul/y': Unknown input node '^block1b_drop/cond/switch_t'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-16-a5d3ba8b443c> in <module>
      3 
      4 # convert to onnx model
----> 5 onnx_model = keras2onnx.convert_keras(model, model.name)
      6 
      7 # runtime prediction

~/anaconda3/lib/python3.6/site-packages/keras2onnx/main.py in convert_keras(model, name, doc_string, target_opset, channel_first_inputs, debug_mode, custom_op_conversions)
     98                         custom_op_dict=custom_op_conversions)
     99     topology.debug_mode = debug_mode
--> 100     parse_graph(topology, sess.graph, target_opset, output_names)
    101     topology.compile()
    102 

~/anaconda3/lib/python3.6/site-packages/keras2onnx/parser.py in parse_graph(topo, graph, target_opset, output_names)
    647         topo.raw_model.add_input_name(str_value)
    648 
--> 649     return _parse_graph_scope(graph, keras_layer_ts_map, topo, top_level, output_names)

~/anaconda3/lib/python3.6/site-packages/keras2onnx/parser.py in _parse_graph_scope(graph, keras_node_dict, topology, top_scope, output_names)
    597             _convert_keras_timedistributed(graph, nodes, layer_key_, model_, varset)
    598         elif layer_key_ is None or get_converter(type(layer_key_)) is None:
--> 599             _convert_general_scope(nodes, varset)
    600         else:
    601             _convert_keras_scope(graph, nodes, layer_key_, model_, varset)

~/anaconda3/lib/python3.6/site-packages/keras2onnx/parser.py in _convert_general_scope(node_list, varset)
    299 
    300     sess = keras.backend.get_session()
--> 301     subgraph, replacement = create_subgraph(sess.graph, node_list, sess, operator.full_name)
    302     setattr(operator, 'subgraph', subgraph)
    303     vars_, ts = _locate_inputs_by_node(node_list, varset)

~/anaconda3/lib/python3.6/site-packages/keras2onnx/subgraph.py in create_subgraph(tf_graph, node_list, sess, dst_scope)
    135     with tf.Graph().as_default() as sub_graph:
    136         im_scope = "" if dst_scope is None else dst_scope
--> 137         tf.import_graph_def(output_graph_def, name=im_scope)
    138         if im_scope:
    139             replacement = {k_: im_scope + '/' + k_ for k_ in replacement}

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    429       except errors.InvalidArgumentError as e:
    430         # Convert to ValueError for backwards compatibility.
--> 431         raise ValueError(str(e))
    432 
    433     # Create _DefinedFunctions for any imported functions.

ValueError: Node 'block1b_drop/cond/mul/y': Unknown input node '^block1b_drop/cond/switch_t'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions