Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Zengyi.Qin authored and Zengyi.Qin committed Nov 8, 2019
1 parent d12fd84 commit 78ef998
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 63 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ pip install -r requirements.txt

### Quick Demo

Run pushing:
Run hammering:
```bash
sh scripts/run_push_test.sh
sh scripts/run_hammer_test.sh
```

Run reaching:
```bash
sh scripts/run_reach_test.sh
```

Run hammering:
Run pushing:
```bash
sh scripts/run_hammer_test.sh
sh scripts/run_push_test.sh
```

### Train from Scratch
Expand Down
208 changes: 204 additions & 4 deletions keypoints/cvae/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,13 @@ def rectify_keypoints(point_cloud,
point_cloud: A numpy array of point cloud.
grasp_point: The grasp point.
funct_point: The function point.
funct_on_hull: A bool indicating whether the function point
should be on the convex hull of the clustering centers.
grasp_clusters: The number of clusters for rectifying the grasp point.
funct_clusters: The number of clusters for rectifying the function point.
funct_on_hull: A bool indicating whether the
function point should be on the convex
hull of the clustering centers.
grasp_clusters: The number of clusters for
rectifying the grasp point.
funct_clusters: The number of clusters for
rectifying the function point.
Returns:
grasp_point: The rectified grasp point.
Expand Down Expand Up @@ -881,14 +884,21 @@ def forward_keypoint(point_cloud_tf,
num_points: The number of points in the point cloud.
num_samples: The number of keypoint candidates
produced by the keypoint generator.
dist_thres: The maximum allowed Chamfer distance
between the keypoints and the point cloud.
num_funct_vect: The number of vectors pointing from
the function point to the effect point.
funct_on_hull: Whether the function point should on
the convex hull of the cluster centers of the
point cloud.
Returns:
top_keypoints: The predicted grasp point and function point.
top_funct_vect: The predicted vector pointing from the function point to
the effect point.
top_score: The score of the keypoints.
"""

point_cloud_tf = tf.reshape(
point_cloud_tf, [1, num_points, 3])
point_cloud = tf.tile(
Expand Down Expand Up @@ -957,6 +967,22 @@ def forward_action(point_cloud_tf,
num_points=1024,
num_samples=256,
dist_thres=0.3):
"""A forward pass that predicts the actions from the visual
observation. This is only for the End-to-End baseline.
Args:
point_cloud_tf: A tensor of point cloud.
num_points: The number of points in the point cloud.
num_samples: The number of action candidates
produced by the action generator.
dist_thres: The maximum allowed Chamfer distance between
the grasp position and the point cloud.
Returns:
top_action: The predicted action.
top_score: The score of the action.
"""

point_cloud_tf = tf.reshape(
point_cloud_tf, [1, num_points, 3])
point_cloud = tf.tile(
Expand Down Expand Up @@ -1001,6 +1027,23 @@ def train_vae_grasp(data_path,
save_step=6000,
model_path=None,
optimizer='Adam'):
"""Trains the VAE for grasp generation.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
optimizer: Adam or SGDM.
Returns:
None.
"""
loader = GraspReader(data_path)
graph = build_grasp_training_graph()
pose_rot_train = graph['pose_rot']
Expand Down Expand Up @@ -1135,6 +1178,26 @@ def train_vae_keypoint(data_path,
model_path=None,
task_name='task',
optimizer='Adam'):
"""Trains the VAE for keypoint generation.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: Adam or SGDM.
Returns:
None.
"""
loader = KeypointReader(data_path)
num_funct_vect = loader.num_funct_vect

Expand Down Expand Up @@ -1249,6 +1312,28 @@ def train_vae_action(data_path,
model_path=None,
task_name='task',
optimizer='Adam'):
"""Trains the VAE for action generation. This is only for
the End-to-End baseline where the actions are directly
predicted from the visual observation.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: Adam or SGDM.
Returns:
None.
"""
loader = ActionReader(data_path)

graph = build_action_training_graph()
Expand Down Expand Up @@ -1362,6 +1447,24 @@ def train_gcnn_grasp(data_path,
model_path=None,
optimizer='SGDM',
lr_init=8e-4):
"""Trains the grasp evaluation network.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
optimizer: Adam or SGDM.
lr_init: The initial learning rate.
Returns:
None.
"""

loader = GraspReader(data_path)
graph = build_grasp_training_graph()
Expand Down Expand Up @@ -1464,6 +1567,23 @@ def train_gcnn_grasp(data_path,


def load_samples(loader, batch_size, stage, noise_level=0.2):
"""Load training and evaluation data.
Args:
loader: The Reader instance.
batch_size: The batch size for training or evaluation.
stage: 'train' or 'val'.
noise_level: The noise to be added to the negative examples.
Returns:
p_np: A numpy array of point cloud.
grasp_np: A numpy array of grasp point.
funct_np: A numpy array of function point.
funct_vect_np: A numpy array of the vector pointing from the
function point to the effect point.
label: The binary success label of the keypoints given the
point cloud as visual observation.
"""
if stage == 'train':
pos_p_np, pos_k_np = loader.sample_pos_train(batch_size // 2)
neg_p_np, neg_k_np = loader.sample_neg_train(batch_size // 2)
Expand All @@ -1487,6 +1607,21 @@ def load_samples(loader, batch_size, stage, noise_level=0.2):


def load_samples_action(loader, batch_size, stage):
"""Loads the training and evaluation data.
Args:
loader: A Reader instance.
batch_size: The training or evaluation batch size.
stage: 'train' or 'val'.
Return:
p_np: A numpy array of point cloud.
grasp_np: A numpy array of grasp point.
trans_np: The translation part of the action.
rot_np: The rotation part of the action.
label: The binary success label of the actions given the
point cloud as visual observation.
"""
if stage == 'train':
pos_p_np, pos_a_np = loader.sample_pos_train(batch_size // 2)
neg_p_np, neg_a_np = loader.sample_neg_train(batch_size // 2)
Expand Down Expand Up @@ -1520,6 +1655,26 @@ def train_discr_keypoint(data_path,
model_path=None,
task_name='task',
optimizer='Adam'):
"""Trains the keypoint evaluation network.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: 'Adam' or 'SGDM'.
Returns:
None.
"""
loader = KeypointReader(data_path)
num_funct_vect = loader.num_funct_vect

Expand Down Expand Up @@ -1628,6 +1783,26 @@ def train_discr_action(data_path,
model_path=None,
task_name='task',
optimizer='Adam'):
"""Trains the action evaluation network.
Args:
data_path: The training data in a single h5df file.
steps: The total number of training steps.
batch_size: The training batch size.
eval_size: The evaluation batch size.
l2_weight: The L2 regularization weight.
log_step: The interval for logging.
eval_step: The interval for evaluation.
save_step: The interval for saving the model weights.
model_path: The pretrained model.
task_name: The name of the task to be trained on. This
is only for distinguishing the name of log files and
the saved model.
optimizer: 'Adam' or 'SGDM'.
Returns:
None.
"""
loader = ActionReader(data_path)

graph = build_action_training_graph()
Expand Down Expand Up @@ -1726,6 +1901,20 @@ def inference_grasp(data_path,
score_thres=0.7,
dist_thres=0.2,
show_best=True):
"""Predicts the grasps offline.
Args:
data_path: The point cloud data in hdf5.
model_path: The pre-trained model.
batch_size: The batch size in inference.
score_thres: The minimum allowed score of the predicted grasps.
dist_thres: The maximum allowed Chamfer distance between the
grasp and the input point cloud.
show_best: Whether to only visualize the best predicted grasp.
Returns:
None.
"""
loader = GraspReader(data_path)
graph = build_grasp_inference_graph(num_samples=batch_size)
point_cloud_tf = graph['point_cloud_tf']
Expand Down Expand Up @@ -1778,6 +1967,17 @@ def inference_keypoint(data_path,
model_path,
batch_size=128,
num_points=1024):
"""Predicts the keypoints offline.
Args:
data_path: The point cloud data in hdf5.
model_path: The pre-trained model.
batch_size: The batch size in inference.
num_points: The points in a single frame of point cloud.
Returns:
None.
"""
loader = KeypointReader(data_path)
point_cloud_tf = tf.placeholder(
dtype=tf.float32, shape=[1, num_points, 3])
Expand Down
Loading

0 comments on commit 78ef998

Please sign in to comment.