
Commit

update
Zengyi.Qin authored and Zengyi.Qin committed Nov 8, 2019
1 parent e8b2390 commit 447f9dc
Showing 3 changed files with 13 additions and 204 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -41,6 +41,16 @@ Run pushing:
sh scripts/run_push_test.sh
```

Run tool generation:
```bash
cd keypoints/toolgen
```
Run `hammer_gen.py`, `push_gen.py`, or `reach_gen.py` to generate tools for the three different tasks. For example, for the hammering task you can run:
```bash
python hammer_gen.py
```
The results are saved in `KETO/keypoints/toolgen/visualize`.
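
The other two generators run the same way; a usage sketch based on the script names above:
```bash
python push_gen.py   # tool generation for the pushing task
python reach_gen.py  # tool generation for the reaching task
```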

### Train from Scratch

#### Grasping
198 changes: 3 additions & 195 deletions keypoints/cvae/build.py
@@ -8,11 +8,10 @@
from scipy.spatial import ConvexHull
from sklearn.cluster import KMeans

from cvae.reader import GraspReader, KeypointReader, ActionReader
from cvae.encoder import GraspEncoder, KeypointEncoder, ActionEncoder
from cvae.decoder import GraspDecoder, KeypointDecoder, ActionDecoder
from cvae.reader import GraspReader, KeypointReader
from cvae.encoder import GraspEncoder, KeypointEncoder
from cvae.decoder import GraspDecoder, KeypointDecoder
from cvae.discriminator import GraspDiscriminator, KeypointDiscriminator
from cvae.discriminator import ActionDiscriminator

import matplotlib as mpl
mpl.use('Agg')
@@ -559,108 +558,6 @@ def build_keypoint_training_graph(num_points=1024,
    return training_graph


def build_action_training_graph(num_points=1024):
    """Builds the tensorflow graph for training the action model.

    This is only for the End-to-End baseline where we directly
    predict the actions without the keypoint representations.

    Args:
        num_points: The number of points in one frame of point cloud.

    Returns:
        The training graph for the action model.
    """
    point_cloud_tf = tf.placeholder(dtype=tf.float32,
                                    shape=[None, num_points, 3])
    grasp_point_tf = tf.placeholder(dtype=tf.float32,
                                    shape=[None, 3])
    translation_tf = tf.placeholder(dtype=tf.float32,
                                    shape=[None, 2])
    rotation_tf = tf.placeholder(dtype=tf.float32,
                                 shape=[None, 2])

    actions = [grasp_point_tf, translation_tf, rotation_tf]
    actions_label_tf = tf.placeholder(dtype=tf.float32,
                                      shape=[None, 1])

    latent_var = ActionEncoder().build_model(tf.reshape(
        point_cloud_tf, (-1, num_points, 1, 3)), actions)
    z_mean, z_std = tf.split(latent_var, 2, axis=1)
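    # Note: the next line is the standard VAE reparameterization trick,
    # z = mu + sigma * eps with eps ~ N(0, I), which keeps the sampling
    # step differentiable with respect to the encoder outputs.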
    z = z_mean + z_std * tf.random.normal(tf.shape(z_std))
    z_std = tf.reduce_mean(reduce_std(z, axis=1))
    z_mean = tf.reduce_mean(z_mean)

    actions_vae = ActionDecoder().build_model(tf.reshape(
        point_cloud_tf, (-1, num_points, 1, 3)), latent_var)

    [grasp_point_vae, translation_vae,
     rotation_vae] = actions_vae

    loss_vae_grasp = tf.reduce_mean(
        tf.abs(grasp_point_vae - grasp_point_tf))

    loss_vae_trans = tf.reduce_mean(
        tf.abs(translation_vae - translation_tf))

    loss_vae_rot = tf.reduce_mean(
        tf.abs(rotation_vae - rotation_tf))

    point_cloud_mean = tf.reduce_mean(
        point_cloud_tf, axis=1)

    std_gt_grasp = tf.reduce_mean(
        reduce_std(grasp_point_tf -
                   point_cloud_mean, axis=0))
    std_gt_trans = tf.reduce_mean(reduce_std(
        translation_tf - point_cloud_mean[:, :2], axis=0))

    std_gt_rot = tf.reduce_mean(reduce_std(
        rotation_tf, axis=0))

    miu, sigma = tf.split(latent_var, 2, axis=1)
    loss_vae_mmd = tf.reduce_mean(tf.square(miu) +
                                  tf.square(sigma) -
                                  tf.log(1e-8 + tf.square(sigma)) - 1)
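    # Note: despite the '_mmd' name, each entry of the expression above
    # equals twice the analytic per-dimension KL(N(miu, sigma^2) || N(0, 1))
    # (up to the 1e-8 inside the log), averaged over the batch and latent
    # dimensions; it pulls the approximate posterior toward the prior.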

    discr_logit = ActionDiscriminator().build_model(
        tf.reshape(point_cloud_tf, (-1, num_points, 1, 3)), actions)

    loss_discr = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=actions_label_tf, logits=discr_logit)

    loss_discr = tf.reduce_mean(loss_discr)

    pred_label = tf.cast(tf.greater(tf.sigmoid(discr_logit), 0.5),
                         tf.float32)
    pred_equal_gt = tf.cast(tf.equal(pred_label,
                                     actions_label_tf), tf.float32)

    acc_discr = tf.reduce_mean(pred_equal_gt)

    prec_discr = tf.math.divide(
        tf.reduce_sum(pred_equal_gt * pred_label),
        tf.reduce_sum(pred_label) + 1e-6)
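    # Note: pred_equal_gt * pred_label is 1 only for true positives, so
    # prec_discr is the discriminator's precision TP / (TP + FP); the 1e-6
    # guards against division by zero when nothing is predicted positive.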

    training_graph = {'point_cloud_tf': point_cloud_tf,
                      'grasp_point_tf': grasp_point_tf,
                      'translation_tf': translation_tf,
                      'rotation_tf': rotation_tf,
                      'actions_label_tf': actions_label_tf,
                      'loss_vae_grasp': loss_vae_grasp,
                      'loss_vae_trans': loss_vae_trans,
                      'loss_vae_rot': loss_vae_rot,
                      'loss_vae_mmd': loss_vae_mmd,
                      'z_mean': z_mean,
                      'z_std': z_std,
                      'std_gt_grasp': std_gt_grasp,
                      'std_gt_trans': std_gt_trans,
                      'std_gt_rot': std_gt_rot,
                      'loss_discr': loss_discr,
                      'acc_discr': acc_discr,
                      'prec_discr': prec_discr}
    return training_graph


def get_learning_rate(step, steps, init=5e-4):
    """Adjusts the learning rate in training.
@@ -755,41 +652,6 @@ def build_keypoint_inference_graph(num_points=1024,
    return inference_graph


def build_action_inference_graph(num_points=1024,
                                 num_samples=128):
    """Builds the inference graph for action prediction.

    This is only for the End-to-End baseline where we
    directly predict the actions without the keypoint
    representations.

    Args:
        num_points: The number of points in point cloud.
        num_samples: The number of samples from the action generator (VAE).

    Returns:
        The inference graph.
    """
    point_cloud_tf = tf.placeholder(
        tf.float32, [1, num_points, 3])
    point_cloud = tf.tile(
        point_cloud_tf, [num_samples, 1, 1])
    latent_var = tf.concat(
        [tf.zeros([num_samples, 2], dtype=tf.float32),
         tf.ones([num_samples, 2], dtype=tf.float32)],
        axis=1)
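    # Note: zero means concatenated with unit stds make the decoder draw
    # its num_samples latent codes from the standard normal prior N(0, I).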
    actions_vae = ActionDecoder().build_model(
        tf.reshape(point_cloud, (-1, num_points, 1, 3)),
        latent_var)

    score = ActionDiscriminator().build_model(
        tf.expand_dims(point_cloud, 2), actions_vae)
    inference_graph = {'point_cloud_tf': point_cloud_tf,
                       'actions': actions_vae,
                       'score': score}
    return inference_graph


def forward_grasp(point_cloud_tf,
                  grasp_keypoint,
                  num_points=1024,
@@ -963,60 +825,6 @@ def forward_keypoint(point_cloud_tf,
    return top_keypoints, top_funct_vect, top_score


def forward_action(point_cloud_tf,
                   num_points=1024,
                   num_samples=256,
                   dist_thres=0.3):
    """A forward pass that predicts the actions from the visual
    observation. This is only for the End-to-End baseline.

    Args:
        point_cloud_tf: A tensor of point cloud.
        num_points: The number of points in the point cloud.
        num_samples: The number of action candidates
            produced by the action generator.
        dist_thres: The maximum allowed Chamfer distance between
            the grasp position and the point cloud.

    Returns:
        top_actions: The predicted actions.
        top_score: The score of the predicted actions.
    """
    point_cloud_tf = tf.reshape(
        point_cloud_tf, [1, num_points, 3])
    point_cloud = tf.tile(
        point_cloud_tf, [num_samples, 1, 1])
    latent_var = tf.concat(
        [tf.zeros([num_samples, 2], dtype=tf.float32),
         tf.ones([num_samples, 2], dtype=tf.float32)],
        axis=1)
    actions_vae = ActionDecoder().build_model(
        tf.reshape(point_cloud, (-1, num_points, 1, 3)),
        latent_var)

    grasp = tf.reshape(actions_vae[0], [num_samples, 1, 3])
    dist = tf.reduce_min(
        tf.linalg.norm(grasp - point_cloud, axis=2), axis=1)
    dist_min = tf.reduce_min(dist)

    mask = tf.logical_or(
        tf.less(dist, dist_thres),
        tf.equal(dist, dist_min))
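    # Note: the tf.equal(dist, dist_min) term keeps the single closest
    # candidate even when every grasp point is farther than dist_thres
    # from the cloud, so the mask can never be empty.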
    actions_vae = [tf.boolean_mask(a, mask, axis=0) for a in actions_vae]
    point_cloud = tf.boolean_mask(point_cloud, mask, axis=0)

    score = ActionDiscriminator().build_model(
        tf.expand_dims(point_cloud, 2), actions_vae)

    index = tf.argmax(tf.reshape(score, [-1]), 0)
    top_score = score[index]
    top_actions = [tf.expand_dims(a[index], 0) for a in actions_vae]

    return top_actions, top_score


def train_vae_grasp(data_path,
                    steps=60000,
                    batch_size=256,
9 changes: 0 additions & 9 deletions keypoints/cvae/encoder.py
@@ -122,12 +122,3 @@ def build_model(self, x, ks, v=None):
        sigma_sp = tf.nn.softplus(sigma)
        z = tf.concat([miu, sigma_sp], axis=1)
        return z

        x = self.fc_layer(x, 256, name='fc1')
        x = self.fc_layer(x, 256, name='fc2')
        z = self.fc_layer(x, 4, linear=True, name='out')

        miu, sigma = tf.split(z, 2, axis=1)
        sigma_sp = tf.nn.softplus(sigma)
        z = tf.concat([miu, sigma_sp], axis=1)
        return z
