Skip to content

Commit 878d3ab

Browse files
committed
Add example of saving GraphDef and variables of a converted model.
1 parent a8120f3 commit 878d3ab

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

examples/save_model/.gitignore

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
checkpoint
2+
3+
*.index
4+
*.meta
5+
*.pb
6+
*.pbtxt
7+
*.data-*-of-*
8+
9+
*.py
10+
!save_model.py
11+
!__init__.py

examples/save_model/READMD.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Example of Saving Model
2+
3+
This is an example of how to save the GraphDef and variables of a converted model
4+
in the Tensorflow official form. By doing this, the converted model can be
5+
conveniently applied with Tensorflow APIs in other languages.
6+
7+
For example, if a converted model is named "VGG", the generated code file should
8+
be named as "VGG.py", and the class name inside should remain "CaffeNet".
9+
10+
The module "VGG" should be able to be directly imported. So put it inside the
11+
[save_graphdef](save_graphdef) folder, or add it to "sys.path".
12+
13+
To save model variables, pass the path of the converted data file (e.g. VGG.npy)
14+
to the parameter "--data-input-path".
15+
16+
A "VGG_frozen.pb' is also generated with all variables converted into constants
17+
in the saved graph.

examples/save_model/__init__.py

Whitespace-only changes.

examples/save_model/save_model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import os.path as osp
5+
import sys
6+
7+
import tensorflow as tf
8+
from tensorflow.python.tools import freeze_graph
9+
from tensorflow.python.training import saver as saver_lib
10+
11+
12+
def save(name, data_input_path):
13+
def getpardir(path): return osp.split(path)[0]
14+
sys.path.append(getpardir(getpardir(getpardir(osp.realpath(__file__)))))
15+
# Import the converted model's class
16+
caffe_net_module = __import__(name)
17+
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
18+
image_input = tf.placeholder(tf.float32, shape=[1, 227, 227, 3], name="data")
19+
net = caffe_net_module.CaffeNet({'data': image_input})
20+
21+
# Save protocol buffer
22+
pb_name = name + '.pb'
23+
tf.train.write_graph(sess.graph_def, '.', pb_name + 'txt', True)
24+
tf.train.write_graph(sess.graph_def, '.', pb_name, False)
25+
26+
if data_input_path is not None:
27+
# Load the data
28+
sess.run(tf.global_variables_initializer())
29+
net.load(data_input_path, sess)
30+
# Save the data
31+
saver = saver_lib.Saver(tf.global_variables())
32+
checkpoint_prefix = osp.join(osp.curdir, name + '.ckpt')
33+
checkpoint_path = saver.save(sess, checkpoint_prefix)
34+
35+
# Freeze the graph
36+
freeze_graph.freeze_graph(pb_name, "",
37+
True, checkpoint_path, 'fc8/fc8',
38+
'save/restore_all', 'save/Const:0',
39+
name + '_frozen.pb', False, "")
40+
41+
42+
def main():
43+
parser = argparse.ArgumentParser()
44+
parser.add_argument('name', help='Name of the converted model')
45+
parser.add_argument('--data-input-path', help='Converted data input path')
46+
args = parser.parse_args()
47+
save(args.name, args.data_input_path)
48+
49+
50+
if __name__ == '__main__':
51+
main()

0 commit comments

Comments
 (0)