|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +import argparse |
| 4 | +import os.path as osp |
| 5 | +import sys |
| 6 | + |
| 7 | +import tensorflow as tf |
| 8 | +from tensorflow.python.tools import freeze_graph |
| 9 | +from tensorflow.python.training import saver as saver_lib |
| 10 | + |
| 11 | + |
| 12 | +def save(name, data_input_path): |
| 13 | + def getpardir(path): return osp.split(path)[0] |
| 14 | + sys.path.append(getpardir(getpardir(getpardir(osp.realpath(__file__))))) |
| 15 | + # Import the converted model's class |
| 16 | + caffe_net_module = __import__(name) |
| 17 | + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: |
| 18 | + image_input = tf.placeholder(tf.float32, shape=[1, 227, 227, 3], name="data") |
| 19 | + net = caffe_net_module.CaffeNet({'data': image_input}) |
| 20 | + |
| 21 | + # Save protocol buffer |
| 22 | + pb_name = name + '.pb' |
| 23 | + tf.train.write_graph(sess.graph_def, '.', pb_name + 'txt', True) |
| 24 | + tf.train.write_graph(sess.graph_def, '.', pb_name, False) |
| 25 | + |
| 26 | + if data_input_path is not None: |
| 27 | + # Load the data |
| 28 | + sess.run(tf.global_variables_initializer()) |
| 29 | + net.load(data_input_path, sess) |
| 30 | + # Save the data |
| 31 | + saver = saver_lib.Saver(tf.global_variables()) |
| 32 | + checkpoint_prefix = osp.join(osp.curdir, name + '.ckpt') |
| 33 | + checkpoint_path = saver.save(sess, checkpoint_prefix) |
| 34 | + |
| 35 | + # Freeze the graph |
| 36 | + freeze_graph.freeze_graph(pb_name, "", |
| 37 | + True, checkpoint_path, 'fc8/fc8', |
| 38 | + 'save/restore_all', 'save/Const:0', |
| 39 | + name + '_frozen.pb', False, "") |
| 40 | + |
| 41 | + |
| 42 | +def main(): |
| 43 | + parser = argparse.ArgumentParser() |
| 44 | + parser.add_argument('name', help='Name of the converted model') |
| 45 | + parser.add_argument('--data-input-path', help='Converted data input path') |
| 46 | + args = parser.parse_args() |
| 47 | + save(args.name, args.data_input_path) |
| 48 | + |
| 49 | + |
| 50 | +if __name__ == '__main__': |
| 51 | + main() |
0 commit comments