-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmerge.py
73 lines (60 loc) · 2.37 KB
/
merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Merges the keypoints model and the grasping model."""
import tensorflow as tf
import argparse
from cvae.build import build_grasp_inference_graph
from cvae.build import build_keypoint_inference_graph
parser = argparse.ArgumentParser(
    description='Merges the keypoints model and the grasping model.')
parser.add_argument('--model',
                    type=str,
                    required=True,
                    choices=['grasp', 'keypoint', 'grasp_keypoint'],
                    help='Which pair of checkpoints to merge.')
parser.add_argument('--vae',
                    type=str,
                    help='Checkpoint of the generation network (VAE); '
                         'required for --model grasp/keypoint.')
parser.add_argument('--discr',
                    type=str,
                    help='Checkpoint of the evaluation network; '
                         'required for --model grasp/keypoint.')
parser.add_argument('--grasp',
                    type=str,
                    help='Checkpoint of the grasp model; '
                         'required for --model grasp_keypoint.')
parser.add_argument('--keypoint',
                    type=str,
                    help='Checkpoint of the keypoint model; '
                         'required for --model grasp_keypoint.')
# Parsed as an int up front so a bad value is reported as a usage error
# instead of a ValueError while the graph is being built.
parser.add_argument('--num_funct_vect',
                    type=int,
                    default=1,
                    help='Passed to build_keypoint_inference_graph.')
parser.add_argument('--output',
                    type=str,
                    default='./runs/cvae_model',
                    help='Path prefix of the merged checkpoint.')
args = parser.parse_args()

# Each merge mode restores from two checkpoints. A missing path is reported
# here as a usage error rather than surfacing inside Saver.restore.
_REQUIRED_CHECKPOINTS = {
    'grasp': ('vae', 'discr'),
    'keypoint': ('vae', 'discr'),
    'grasp_keypoint': ('grasp', 'keypoint'),
}
_missing_flags = ['--' + name
                  for name in _REQUIRED_CHECKPOINTS.get(args.model, ())
                  if getattr(args, name) is None]
if _missing_flags:
    parser.error('--model %s requires %s'
                 % (args.model, ', '.join(_missing_flags)))
# Build the inference graph(s) whose variables are merged below.
# 'grasp_keypoint' builds both graphs: the grasp graph first, then the
# keypoint graph. Any other model name is rejected.
_known_models = ('grasp', 'keypoint', 'grasp_keypoint')
if args.model not in _known_models:
    raise ValueError(args.model)
if args.model in ('grasp', 'grasp_keypoint'):
    build_grasp_inference_graph()
if args.model in ('keypoint', 'grasp_keypoint'):
    # Only converted when the keypoint graph is actually built.
    _num_funct_vect = int(args.num_funct_vect)
    build_keypoint_inference_graph(num_funct_vect=_num_funct_vect)
def _merge_checkpoints(sess, variables, sources, output):
    """Restore subsets of ``variables`` from several checkpoints, save as one.

    Args:
        sess: the ``tf.Session`` to restore into and save from.
        variables: every global variable of the graph; all of them are
            written to ``output``.
        sources: sequence of ``(keyword, checkpoint_path)`` pairs. Every
            variable whose name contains ``keyword`` is restored from
            ``checkpoint_path``, in the order given. In this script each
            keyword also equals the name of the CLI flag holding the path.
        output: path prefix passed to ``Saver.save``.

    Raises:
        ValueError: if a checkpoint path is missing, if a keyword matches no
            variable, or if some variable matches no keyword. An unmatched
            variable would never be initialized, so ``Saver.save`` would fail
            on it anyway, with a far less readable error.
    """
    subsets = []
    for keyword, checkpoint in sources:
        if checkpoint is None:
            raise ValueError('No checkpoint given for %r variables (pass --%s)'
                             % (keyword, keyword))
        subset = [var for var in variables if keyword in var.name]
        if not subset:
            raise ValueError('No variable name contains %r; nothing to '
                             'restore from %s' % (keyword, checkpoint))
        subsets.append((subset, checkpoint))

    restored = {var.name for subset, _ in subsets for var in subset}
    unrestored = [var.name for var in variables if var.name not in restored]
    if unrestored:
        raise ValueError('These variables match none of %s and would be saved '
                         'uninitialized: %s'
                         % ([keyword for keyword, _ in sources], unrestored))

    saver = tf.train.Saver(var_list=variables)
    for subset, checkpoint in subsets:
        tf.train.Saver(var_list=subset).restore(sess, checkpoint)
    saver.save(sess, output)


with tf.Session() as sess:
    all_vars = tf.global_variables()
    if args.model in ['grasp', 'keypoint']:
        # Merges the generation network (VAE) and
        # the evaluation network (binary classifier).
        _merge_checkpoints(sess, all_vars,
                           [('vae', args.vae), ('discr', args.discr)],
                           args.output)
    elif args.model == 'grasp_keypoint':
        # Merges the grasp prediction network and the keypoints network.
        _merge_checkpoints(sess, all_vars,
                           [('grasp', args.grasp),
                            ('keypoint', args.keypoint)],
                           args.output)
    else:
        raise ValueError(args.model)