@@ -195,6 +195,7 @@ def get_anchor_points(
195
195
use_from_simulator = False ,
196
196
handle_mapping = None ,
197
197
names_to_handles = None ,
198
+ gripper_in_first_phase = True ,
198
199
):
199
200
if use_from_simulator :
200
201
handle_mapping = {
@@ -212,7 +213,7 @@ def get_anchor_points(
212
213
names = BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES
213
214
214
215
# If it's the first phase, we also omit the gripper.
215
- if phase == TASK_DICT [task_name ]["phase_order" ][0 ]:
216
+ if phase == TASK_DICT [task_name ]["phase_order" ][0 ] and gripper_in_first_phase :
216
217
names += GRIPPER_OBJ_NAMES
217
218
218
219
return filter_out_names (rgb , point_cloud , mask , handle_mapping , names )
@@ -279,6 +280,7 @@ def __init__(
279
280
anchor_mode : AnchorMode = AnchorMode .SINGLE_OBJECT ,
280
281
action_mode : ActionMode = ActionMode .OBJECT ,
281
282
include_wrist_cam : bool = False ,
283
+ gripper_in_first_phase : bool = True ,
282
284
) -> None :
283
285
"""Dataset for RL-Bench placement tasks.
284
286
@@ -336,6 +338,7 @@ def leaf_fn(path, x):
336
338
raise ValueError ("Anchor mode must be one of the AnchorMode enum values." )
337
339
self .action_mode = action_mode
338
340
self .anchor_mode = anchor_mode
341
+ self .gripper_in_first_phase = gripper_in_first_phase
339
342
340
343
if cache :
341
344
self .memory = Memory (
@@ -449,6 +452,7 @@ def _select_anchor_vals(rgb, point_cloud, mask):
449
452
use_from_simulator = False ,
450
453
handle_mapping = self .handle_mapping ,
451
454
names_to_handles = self .names_to_handles ,
455
+ gripper_in_first_phase = self .gripper_in_first_phase ,
452
456
)
453
457
454
458
# Merge all the initial point clouds and masks into one.
0 commit comments