@@ -39,7 +39,7 @@ def get_rgb_point_cloud_by_object_names(rgb, point_cloud, seg_labels, names):
39
39
return get_rgb_point_cloud_by_object_handles (rgb , point_cloud , seg_labels , handles )
40
40
41
41
42
- def obs_to_rgb_point_cloud (obs ):
42
+ def obs_to_rgb_point_cloud (obs , include_wrist_cam = False ):
43
43
# Get the overhead, left, front, and right RGB images.
44
44
overhead_rgb = obs .overhead_rgb
45
45
left_rgb = obs .left_shoulder_rgb
@@ -84,31 +84,25 @@ def obs_to_rgb_point_cloud(obs):
84
84
85
85
# Stack the RGB and point cloud images together.
86
86
rgb = np .vstack (
87
- (
88
- overhead_rgb ,
89
- left_rgb ,
90
- right_rgb ,
91
- front_rgb ,
92
- # wrist_rgb,
93
- )
87
+ (overhead_rgb , left_rgb , right_rgb , front_rgb )
88
+ if not include_wrist_cam
89
+ else (overhead_rgb , left_rgb , right_rgb , front_rgb , wrist_rgb )
94
90
)
95
91
point_cloud = np .vstack (
96
- (
92
+ (overhead_point_cloud , left_point_cloud , right_point_cloud , front_point_cloud )
93
+ if not include_wrist_cam
94
+ else (
97
95
overhead_point_cloud ,
98
96
left_point_cloud ,
99
97
right_point_cloud ,
100
98
front_point_cloud ,
101
- # wrist_point_cloud,
99
+ wrist_point_cloud ,
102
100
)
103
101
)
104
102
mask = np .vstack (
105
- (
106
- overhead_mask ,
107
- left_mask ,
108
- right_mask ,
109
- front_mask ,
110
- # wrist_mask,
111
- )
103
+ (overhead_mask , left_mask , right_mask , front_mask )
104
+ if not include_wrist_cam
105
+ else (overhead_mask , left_mask , right_mask , front_mask , wrist_mask )
112
106
)
113
107
114
108
return rgb , point_cloud , mask
@@ -284,6 +278,7 @@ def __init__(
284
278
debugging : bool = False ,
285
279
anchor_mode : AnchorMode = AnchorMode .SINGLE_OBJECT ,
286
280
action_mode : ActionMode = ActionMode .OBJECT ,
281
+ include_wrist_cam : bool = False ,
287
282
) -> None :
288
283
"""Dataset for RL-Bench placement tasks.
289
284
@@ -309,6 +304,7 @@ def __init__(
309
304
self .variation = 0
310
305
self .debugging = debugging
311
306
self .use_first_as_init_keyframe = use_first_as_init_keyframe
307
+ self .include_wrist_cam = include_wrist_cam
312
308
313
309
if self .task_name not in TASK_DICT :
314
310
raise ValueError (f"Task name { self .task_name } not supported." )
@@ -456,7 +452,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
456
452
)
457
453
458
454
# Merge all the initial point clouds and masks into one.
459
- init_rgb , init_point_cloud , init_mask = obs_to_rgb_point_cloud (initial_obs )
455
+ init_rgb , init_point_cloud , init_mask = obs_to_rgb_point_cloud (
456
+ initial_obs , self .include_wrist_cam
457
+ )
460
458
461
459
init_action_rgb , init_action_point_cloud = _select_action_vals (
462
460
init_rgb , init_point_cloud , init_mask
@@ -467,7 +465,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
467
465
)
468
466
469
467
# Merge all the key point clouds and masks into one.
470
- key_rgb , key_point_cloud , key_mask = obs_to_rgb_point_cloud (key_obs )
468
+ key_rgb , key_point_cloud , key_mask = obs_to_rgb_point_cloud (
469
+ key_obs , self .include_wrist_cam
470
+ )
471
471
472
472
# Split the key point cloud and rgb into action and anchor.
473
473
key_action_rgb , key_action_point_cloud = _select_action_vals (
0 commit comments