Skip to content

Commit

Permalink
:q
Browse files Browse the repository at this point in the history
  • Loading branch information
Zengyi-Qin committed Aug 14, 2019
1 parent 446e10f commit 9a58c0e
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 106 deletions.
Empty file added keypoints/__init__.py
Empty file.
20 changes: 16 additions & 4 deletions robovat/envs/push_point_cloud_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,12 @@ def execute_action(self, action):
#
# Grasping
#
self._execute_action_grasping(action_grasp)
is_good_grasp = self._execute_action_grasping(action_grasp)

if not is_good_grasp:
return
#
# Hammering
# Pushing
#
self._execute_action_pushing(action_task)

Expand Down Expand Up @@ -261,6 +263,8 @@ def _execute_action_grasping(self, action):
self.robot.move_to_gripper_pose(prestart)

elif phase == 'start':
pre_grasp_pose = np.array(self.graspable.pose.position)

self.robot.move_to_gripper_pose(start, straight_line=True)

# Prevent problems caused by unrealistic frictions.
Expand All @@ -276,6 +280,7 @@ def _execute_action_grasping(self, action):

elif phase == 'end':
self.robot.grip(1)
post_grasp_pose = np.array(self.graspable.pose.position)

elif phase == 'postend':
postend = self.robot.end_effector.pose
Expand All @@ -295,7 +300,9 @@ def _execute_action_grasping(self, action):
rolling_friction=1000,
spinning_friction=1000)
self.table.set_dynamics(
lateral_friction=0.1)
lateral_friction=0.3)

return self._good_grasp(pre_grasp_pose, post_grasp_pose)

def _execute_action_pushing(self, action):
"""Execute the pushing action.
Expand Down Expand Up @@ -362,7 +369,7 @@ def _execute_action_pushing(self, action):
phase, num_action_steps)

elif phase == 'start':
self.grasp_cornercase = False
pass

elif phase == 'end':
pass
Expand All @@ -379,6 +386,11 @@ def _draw_path(self, action):
np.random.randint(100)))
plt.close()

def _good_grasp(self, pre, post, thres=0.02):
trans = np.linalg.norm(pre - post)
logger.debug('The tool slips {:.3f}'.format(trans))
return trans < thres

def _wait_until_ready(self, phase, num_action_steps):
ready = False
while(not ready):
Expand Down
1 change: 1 addition & 0 deletions robovat/envs/reward_fns/push_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def on_episode_start(self):

self.graspable = self.env.simulator.bodies[self.graspable_name]
self.env.timeout = False
self.env.grasp_cornercase = False

def get_reward(self):
"""Returns the reward value of the current step."""
Expand Down
105 changes: 5 additions & 100 deletions robovat/policies/push_point_cloud_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,81 +24,6 @@
nest = tf.contrib.framework.nest


class AntipodalGraspSampler(object):
"""Samples random antipodal grasps from a depth image."""

def __init__(self,
time_step_spec,
action_spec,
config,
debug=False):
debug = debug and config.DEBUG

self._time_step_spec = time_step_spec
self._action_spec = action_spec

flat_action_spec = nest.flatten(self._action_spec)
self._action_dtype = flat_action_spec[0].dtype
self._action_shape = flat_action_spec[0].shape

self._sampler = image_grasp_sampler.AntipodalDepthImageGraspSampler(
friction_coef=config.SAMPLER.FRICTION_COEF,
depth_grad_thresh=config.SAMPLER.DEPTH_GRAD_THRESH,
depth_grad_gaussian_sigma=config.SAMPLER.DEPTH_GRAD_GAUSSIAN_SIGMA,
downsample_rate=config.SAMPLER.DOWNSAMPLE_RATE,
max_rejection_samples=config.SAMPLER.MAX_REJECTION_SAMPLES,
crop=config.SAMPLER.CROP,
min_dist_from_boundary=config.SAMPLER.MIN_DIST_FROM_BOUNDARY,
min_grasp_dist=config.SAMPLER.MIN_GRASP_DIST,
angle_dist_weight=config.SAMPLER.ANGLE_DIST_WEIGHT,
depth_samples_per_grasp=config.SAMPLER.DEPTH_SAMPLES_PER_GRASP,
min_depth_offset=config.SAMPLER.MIN_DEPTH_OFFSET,
max_depth_offset=config.SAMPLER.MAX_DEPTH_OFFSET,
depth_sample_window_height=(
config.SAMPLER.DEPTH_SAMPLE_WINDOW_HEIGHT),
depth_sample_window_width=config.SAMPLER.DEPTH_SAMPLE_WINDOW_WIDTH,
gripper_width=config.GRIPPER_WIDTH,
debug=debug)

def __call__(self, time_step, num_samples, seed):
observation = nest.map_structure(lambda x: tf.squeeze(x, 0),
time_step.observation)
depth = observation['depth']
intrinsics = observation['intrinsics']
grasps = tf.py_func(
self._sampler.sample, [depth, intrinsics, num_samples], tf.float32)
grasps = tf.reshape(
grasps, [1, num_samples] + self._action_shape.as_list())
return grasps


class Grasp4DofRandomPolicy(random_tf_policy.RandomTFPolicy):
"""Sample random antipodal grasps."""

def __init__(self,
time_step_spec,
action_spec,
policy_state_spec=(),
config=None,
debug=False):
self._sampler = AntipodalGraspSampler(
time_step_spec, action_spec, config, debug=True)
self._num_samples = 1

super(Grasp4DofRandomPolicy, self).__init__(
time_step_spec,
action_spec,
policy_state_spec)

def _action(self, time_step, policy_state, seed):
actions = self._sampler(
time_step,
self._num_samples,
seed)
action = tf.squeeze(actions, 0)
return policy_step.PolicyStep(action, policy_state)


class PushPointCloudPolicy(point_cloud_policy.PointCloudPolicy):

TARGET_REGION = {
Expand All @@ -110,6 +35,10 @@ class PushPointCloudPolicy(point_cloud_policy.PointCloudPolicy):
'yaw': 0,
}

TABLE_POSE = [
[0.6, 0, 0.0],
[0, 0, 0]]

def __init__(self,
time_step_spec,
action_spec,
Expand All @@ -121,7 +50,7 @@ def __init__(self,
action_spec,
config=config)

self.table_pose = Pose(self.config.SIM.TABLE.POSE)
self.table_pose = Pose(self.TABLE_POSE)
pose = Pose.uniform(**self.TARGET_REGION)
self.target_pose = get_transform(
source=self.table_pose).transform(pose)
Expand Down Expand Up @@ -231,27 +160,3 @@ def _action(self,

return policy_step.PolicyStep(action, policy_state)


class Grasp4DofCemPolicy(cem_policy.CemPolicy):
"""4-DoF grasping policy using CEM."""

def __init__(self,
time_step_spec,
action_spec,
config=None,
debug=False):
initial_sampler = AntipodalGraspSampler(
time_step_spec, action_spec, config, debug=False)
q_network = tf.make_template(
'GQCNN',
GQCNN,
create_scope_now_=True,
time_step_spec=time_step_spec,
action_spec=action_spec)
q_network = q_network()
super(Grasp4DofCemPolicy, self).__init__(
time_step_spec=time_step_spec,
action_spec=action_spec,
q_network=q_network,
initial_sampler=initial_sampler,
config=config)
2 changes: 1 addition & 1 deletion run_push.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ python tools/run_env.py \
--policy PushPointCloudPolicy \
--policy_config configs/policies/push_point_cloud_policy.yaml \
--problem PushPointCloudProblem \
--episodic 0 --num_episodes 8192 --debug 1 \
--episodic 0 --num_episodes 8192 --debug 1 \
--output episodes/push_point_cloud \
--checkpoint keypoints/save/push/cvae_push
4 changes: 3 additions & 1 deletion tools/run_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def main():
else:
policy_config = YamlConfig(args.policy_config).as_easydict()


# Simulator.
if args.use_simulator:
simulator = Simulator(worker_id=args.worker_id,
Expand Down Expand Up @@ -190,7 +191,8 @@ def main():
policy_class = getattr(policies, args.policy)
tf_policy = policy_class(time_step_spec=tf_env.time_step_spec(),
action_spec=tf_env.action_spec(),
config=policy_config,)
config=policy_config)


py_policy = py_tf_policy.PyTFPolicy(tf_policy)

Expand Down

0 comments on commit 9a58c0e

Please sign in to comment.