
Commit

update
Zengyi.Qin authored and Zengyi.Qin committed Nov 8, 2019
1 parent 78ef998 commit e8b2390
Showing 13 changed files with 320 additions and 733 deletions.
296 changes: 0 additions & 296 deletions keypoints/cvae/build.py
@@ -1301,141 +1301,6 @@ def train_vae_keypoint(data_path,
                               task_name, str(step).zfill(6)))


def train_vae_action(data_path,
                     steps=120000,
                     batch_size=256,
                     eval_size=128,
                     l2_weight=1e-6,
                     log_step=20,
                     eval_step=4000,
                     save_step=1000,
                     model_path=None,
                     task_name='task',
                     optimizer='Adam'):
"""Trains the VAE for action generation. This is only for
the End-to-End baseline where the actions are directly
predicted from the visual observation.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: Adam or SGDM.
Returns:
None.
"""
    loader = ActionReader(data_path)

    graph = build_action_training_graph()
    learning_rate = tf.placeholder(tf.float32, shape=())

    point_cloud_tf = graph['point_cloud_tf']
    grasp_point_tf = graph['grasp_point_tf']
    translation_tf = graph['translation_tf']
    rotation_tf = graph['rotation_tf']

    loss_vae_grasp = graph['loss_vae_grasp']
    loss_vae_trans = graph['loss_vae_trans']
    loss_vae_rot = graph['loss_vae_rot']
    loss_vae_mmd = graph['loss_vae_mmd'] * 0.005

    z_mean = graph['z_mean']
    z_std = graph['z_std']

    std_gt_grasp = graph['std_gt_grasp']
    std_gt_trans = graph['std_gt_trans']
    std_gt_rot = graph['std_gt_rot']

    weight_loss = [tf.nn.l2_loss(var) for var
                   in tf.trainable_variables()]
    weight_loss = tf.reduce_sum(weight_loss) * l2_weight

    loss_vae = (loss_vae_grasp + loss_vae_trans +
                loss_vae_rot + loss_vae_mmd)
    loss = weight_loss + loss_vae

    if optimizer == 'Adam':
        train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss)
    elif optimizer == 'SGDM':
        train_op = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=0.9).minimize(loss)
    else:
        raise NotImplementedError

    all_vars = tf.get_collection_ref(
        tf.GraphKeys.GLOBAL_VARIABLES)
    var_list_vae = [var for var in all_vars
                    if 'vae_action' in var.name and
                    'Momentum' not in var.name and
                    'Adam' not in var.name]

    saver = tf.train.Saver(var_list=var_list_vae)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        sess.run([tf.global_variables_initializer()])
        if model_path:
            running_log.write('vae_action_{}'.format(task_name),
                              'loading model from {}'.format(model_path))
            saver.restore(sess, model_path)

        for step in range(steps + 1):
            pos_p_np, pos_a_np = loader.sample_pos_train(batch_size)

            pos_grasp_np, pos_trans_np, pos_rot_np = np.split(
                pos_a_np, [3, 5], axis=1)
            pos_rot_np = np.concatenate([np.cos(pos_rot_np),
                                         np.sin(pos_rot_np)], axis=1)

            feed_dict = {point_cloud_tf: pos_p_np,
                         grasp_point_tf: pos_grasp_np,
                         translation_tf: pos_trans_np,
                         rotation_tf: pos_rot_np,
                         learning_rate: get_learning_rate(step, steps)}

            [_, loss_np, vae_grasp, vae_trans, vae_rot, vae_mmd, weight,
             std_gt_grasp_np, std_gt_trans_np, std_gt_rot_np,
             z_mean_np, z_std_np] = sess.run([
                 train_op, loss, loss_vae_grasp,
                 loss_vae_trans, loss_vae_rot, loss_vae_mmd,
                 weight_loss, std_gt_grasp, std_gt_trans, std_gt_rot,
                 z_mean, z_std],
                 feed_dict=feed_dict)

            if step % log_step == 0:
                running_log.write('vae_action_{}'.format(task_name),
                                  'step: {}/{}, '.format(step, steps) +
                                  'loss: {:.3f}, grasp: {:.3f}/{:.3f}, '.format(
                                      loss_np, vae_grasp, std_gt_grasp_np) +
                                  'trans: {:.3f}/{:.3f}, '.format(
                                      vae_trans, std_gt_trans_np) +
                                  'rot: {:.3f}/{:.3f}, '.format(
                                      vae_rot, std_gt_rot_np) +
                                  'mmd: {:.3f} ({:.3f} {:.3f}), '.format(
                                      vae_mmd, z_mean_np, z_std_np))

            if step > 0 and step % save_step == 0:
                makedir('./runs/vae')
                saver.save(sess,
                           './runs/vae/vae_action_{}_{}'.format(
                               task_name, str(step).zfill(6)))
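
A minimal usage sketch of the deleted train_vae_action (not part of the original file): the data path and task name below are placeholder values, and the call assumes this build.py module and its dependencies (ActionReader, TensorFlow 1.x) are importable as-is.

    # Hypothetical invocation of the End-to-End baseline VAE trainer.
    # 'data/actions.hdf5' and 'hammer' are placeholders, not repository values.
    train_vae_action(data_path='data/actions.hdf5',
                     steps=120000,
                     batch_size=256,
                     task_name='hammer',
                     optimizer='Adam')

With these arguments the function logs progress every log_step steps and writes checkpoints under ./runs/vae/ every save_step steps, following the saver.save path used above.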


def train_gcnn_grasp(data_path,
                     steps=60000,
                     batch_size=256,
@@ -1606,44 +1471,6 @@ def load_samples(loader, batch_size, stage, noise_level=0.2):
    return p_np, grasp_np, funct_np, funct_vect_np, label_np


def load_samples_action(loader, batch_size, stage):
    """Loads the training and evaluation data.
    Args:
        loader: A Reader instance.
        batch_size: The training or evaluation batch size.
        stage: 'train' or 'val'.
    Returns:
        p_np: A numpy array of point clouds.
        grasp_np: A numpy array of grasp points.
        trans_np: The translation part of the action.
        rot_np: The rotation part of the action.
        label_np: The binary success labels of the actions given the
            point cloud as the visual observation.
    """
    if stage == 'train':
        pos_p_np, pos_a_np = loader.sample_pos_train(batch_size // 2)
        neg_p_np, neg_a_np = loader.sample_neg_train(batch_size // 2)
    elif stage == 'val':
        pos_p_np, pos_a_np = loader.sample_pos_val(batch_size // 2)
        neg_p_np, neg_a_np = loader.sample_neg_val(batch_size // 2)
    else:
        raise NotImplementedError

    num_pos, num_neg = pos_p_np.shape[0], neg_p_np.shape[0]
    label_np = np.concatenate(
        [np.ones(shape=(num_pos, 1)),
         np.zeros(shape=(num_neg, 1))],
        axis=0).astype(np.float32)
    p_np = np.concatenate([pos_p_np, neg_p_np], axis=0)
    a_np = np.concatenate([pos_a_np, neg_a_np], axis=0)
    grasp_np, trans_np, rot_np = np.split(
        a_np, [3, 5], axis=1)
    rot_np = np.concatenate([np.cos(rot_np), np.sin(rot_np)], axis=1)
    return p_np, grasp_np, trans_np, rot_np, label_np
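
The split at [3, 5] above implies a 6-column action array: columns 0 to 2 hold the grasp point, columns 3 to 4 the translation, and column 5 a rotation angle that is re-encoded as (cos, sin), presumably to avoid the discontinuity at the angle wrap-around. A small NumPy sketch of that encoding (not part of the original file; the batch size and values are arbitrary):

    import numpy as np

    # Assumed 6-D action layout, inferred from np.split(a_np, [3, 5], axis=1):
    # columns 0-2 grasp point, columns 3-4 translation, column 5 rotation angle.
    a_np = np.random.uniform(-1.0, 1.0, size=(4, 6)).astype(np.float32)
    grasp_np, trans_np, rot_np = np.split(a_np, [3, 5], axis=1)
    rot_np = np.concatenate([np.cos(rot_np), np.sin(rot_np)], axis=1)
    print(grasp_np.shape, trans_np.shape, rot_np.shape)  # (4, 3) (4, 2) (4, 2)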


def train_discr_keypoint(data_path,
                         steps=120000,
                         batch_size=128,
@@ -1772,129 +1599,6 @@ def train_discr_keypoint(data_path,
                                          noise_level, acc_np * 100))


def train_discr_action(data_path,
                       steps=120000,
                       batch_size=128,
                       eval_size=128,
                       l2_weight=1e-6,
                       log_step=20,
                       eval_step=1000,
                       save_step=1000,
                       model_path=None,
                       task_name='task',
                       optimizer='Adam'):
"""Trains the action evaluation network.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: 'Adam' or 'SGDM'.
Returns:
None.
"""
    loader = ActionReader(data_path)

    graph = build_action_training_graph()
    learning_rate = tf.placeholder(tf.float32, shape=())

    point_cloud_tf = graph['point_cloud_tf']
    grasp_point_tf = graph['grasp_point_tf']
    translation_tf = graph['translation_tf']
    rotation_tf = graph['rotation_tf']

    actions_label_tf = graph['actions_label_tf']
    loss_discr = graph['loss_discr']
    acc_discr = graph['acc_discr']

    weight_loss = [tf.nn.l2_loss(var) for var
                   in tf.trainable_variables()]
    weight_loss = tf.reduce_sum(weight_loss) * l2_weight

    loss = weight_loss + loss_discr

    if optimizer == 'Adam':
        train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss)
    elif optimizer == 'SGDM':
        train_op = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=0.9).minimize(loss)
    else:
        raise NotImplementedError

    all_vars = tf.get_collection_ref(
        tf.GraphKeys.GLOBAL_VARIABLES)
    var_list_vae = [var for var in all_vars
                    if 'action_discriminator' in var.name and
                    'Momentum' not in var.name and
                    'Adam' not in var.name]

    saver = tf.train.Saver(var_list=var_list_vae)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        sess.run([tf.global_variables_initializer()])
        if model_path:
            running_log.write('discr_action_{}'.format(task_name),
                              'loading model from {}'.format(model_path))
            saver.restore(sess, model_path)

        for step in range(steps + 1):
            p_np, grasp_np, trans_np, rot_np, label_np = load_samples_action(
                loader, batch_size, 'train')
            feed_dict = {point_cloud_tf: p_np,
                         grasp_point_tf: grasp_np,
                         translation_tf: trans_np,
                         rotation_tf: rot_np,
                         actions_label_tf: label_np,
                         learning_rate: get_learning_rate(step, steps)}

            [_, loss_np, acc_np, weight
             ] = sess.run([
                 train_op, loss, acc_discr, weight_loss],
                 feed_dict=feed_dict)

            if step % log_step == 0:
                running_log.write('discr_action_{}'.format(task_name),
                                  'step: {}/{}, '.format(step, steps) +
                                  'loss: {:.3f}, acc: {:.3f}'.format(
                                      loss_np, acc_np * 100))

            if step > 0 and step % save_step == 0:
                saver.save(sess,
                           './runs/discr/discr_action_{}_{}'.format(
                               task_name, str(step).zfill(6)))

            if step > 0 and step % eval_step == 0:
                for noise_level in [0.1, 0.2, 0.4, 0.8]:
                    [p_np, grasp_np, trans_np,
                     rot_np, label_np] = load_samples_action(
                         loader, batch_size, 'train')
                    feed_dict = {point_cloud_tf: p_np,
                                 grasp_point_tf: grasp_np,
                                 translation_tf: trans_np,
                                 rotation_tf: rot_np,
                                 actions_label_tf: label_np}

                    [acc_np] = sess.run([acc_discr], feed_dict=feed_dict)
                    running_log.write('discr_action_{}'.format(task_name),
                                      'noise: {:.3f}, acc: {:.3f}'.format(
                                          noise_level, acc_np * 100))
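
For completeness, a hypothetical call to the deleted train_discr_action (not part of the original file). The data path, task name, and checkpoint name are placeholders; the checkpoint name only mirrors the './runs/discr/discr_action_{}_{}' pattern used by saver.save above.

    # Hypothetical invocation of the action discriminator trainer, resuming
    # from an earlier checkpoint and using SGD with momentum.
    train_discr_action(data_path='data/actions.hdf5',
                       steps=120000,
                       batch_size=128,
                       model_path='./runs/discr/discr_action_hammer_060000',
                       task_name='hammer',
                       optimizer='SGDM')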


def inference_grasp(data_path,
                    model_path,
                    batch_size=128,

