Skip to content

Commit a791ba1

Browse files
committed
added wrist optionally
1 parent 46bba6e commit a791ba1

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

src/rpad/rlbench_utils/placement_dataset.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_rgb_point_cloud_by_object_names(rgb, point_cloud, seg_labels, names):
3939
return get_rgb_point_cloud_by_object_handles(rgb, point_cloud, seg_labels, handles)
4040

4141

42-
def obs_to_rgb_point_cloud(obs):
42+
def obs_to_rgb_point_cloud(obs, include_wrist_cam=False):
4343
# Get the overhead, left, front, and right RGB images.
4444
overhead_rgb = obs.overhead_rgb
4545
left_rgb = obs.left_shoulder_rgb
@@ -84,31 +84,25 @@ def obs_to_rgb_point_cloud(obs):
8484

8585
# Stack the RGB and point cloud images together.
8686
rgb = np.vstack(
87-
(
88-
overhead_rgb,
89-
left_rgb,
90-
right_rgb,
91-
front_rgb,
92-
# wrist_rgb,
93-
)
87+
(overhead_rgb, left_rgb, right_rgb, front_rgb)
88+
if not include_wrist_cam
89+
else (overhead_rgb, left_rgb, right_rgb, front_rgb, wrist_rgb)
9490
)
9591
point_cloud = np.vstack(
96-
(
92+
(overhead_point_cloud, left_point_cloud, right_point_cloud, front_point_cloud)
93+
if not include_wrist_cam
94+
else (
9795
overhead_point_cloud,
9896
left_point_cloud,
9997
right_point_cloud,
10098
front_point_cloud,
101-
# wrist_point_cloud,
99+
wrist_point_cloud,
102100
)
103101
)
104102
mask = np.vstack(
105-
(
106-
overhead_mask,
107-
left_mask,
108-
right_mask,
109-
front_mask,
110-
# wrist_mask,
111-
)
103+
(overhead_mask, left_mask, right_mask, front_mask)
104+
if not include_wrist_cam
105+
else (overhead_mask, left_mask, right_mask, front_mask, wrist_mask)
112106
)
113107

114108
return rgb, point_cloud, mask
@@ -284,6 +278,7 @@ def __init__(
284278
debugging: bool = False,
285279
anchor_mode: AnchorMode = AnchorMode.SINGLE_OBJECT,
286280
action_mode: ActionMode = ActionMode.OBJECT,
281+
include_wrist_cam: bool = False,
287282
) -> None:
288283
"""Dataset for RL-Bench placement tasks.
289284
@@ -309,6 +304,7 @@ def __init__(
309304
self.variation = 0
310305
self.debugging = debugging
311306
self.use_first_as_init_keyframe = use_first_as_init_keyframe
307+
self.include_wrist_cam = include_wrist_cam
312308

313309
if self.task_name not in TASK_DICT:
314310
raise ValueError(f"Task name {self.task_name} not supported.")
@@ -456,7 +452,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
456452
)
457453

458454
# Merge all the initial point clouds and masks into one.
459-
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(initial_obs)
455+
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(
456+
initial_obs, self.include_wrist_cam
457+
)
460458

461459
init_action_rgb, init_action_point_cloud = _select_action_vals(
462460
init_rgb, init_point_cloud, init_mask
@@ -467,7 +465,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
467465
)
468466

469467
# Merge all the key point clouds and masks into one.
470-
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(key_obs)
468+
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(
469+
key_obs, self.include_wrist_cam
470+
)
471471

472472
# Split the key point cloud and rgb into action and anchor.
473473
key_action_rgb, key_action_point_cloud = _select_action_vals(

0 commit comments

Comments
 (0)