@@ -191,6 +191,72 @@ class AnchorMode(str, Enum):
191
191
SINGLE_OBJECT = "single_object"
192
192
193
193
194
+ def get_anchor_points (
195
+ anchor_mode : AnchorMode ,
196
+ rgb ,
197
+ point_cloud ,
198
+ mask ,
199
+ task_name ,
200
+ phase ,
201
+ use_from_simulator = False ,
202
+ handle_mapping = None ,
203
+ names_to_handles = None ,
204
+ ):
205
+ if anchor_mode == AnchorMode .RAW :
206
+ return rgb , point_cloud
207
+ elif anchor_mode == AnchorMode .BACKGROUND_REMOVED :
208
+ return filter_out_names (
209
+ rgb , point_cloud , mask , handle_mapping , BACKGROUND_NAMES
210
+ )
211
+ elif anchor_mode == AnchorMode .BACKGROUND_ROBOT_REMOVED :
212
+ return filter_out_names (
213
+ rgb ,
214
+ point_cloud ,
215
+ mask ,
216
+ handle_mapping ,
217
+ BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES ,
218
+ )
219
+ elif anchor_mode == AnchorMode .SINGLE_OBJECT :
220
+ if use_from_simulator :
221
+ return get_rgb_point_cloud_by_object_names (
222
+ rgb ,
223
+ point_cloud ,
224
+ mask ,
225
+ TASK_DICT [task_name ]["phase" ][phase ]["anchor_obj_names" ],
226
+ )
227
+ else :
228
+ return get_rgb_point_cloud_by_object_handles (
229
+ rgb ,
230
+ point_cloud ,
231
+ mask ,
232
+ names_to_handles [phase ]["anchor_obj_names" ],
233
+ )
234
+ else :
235
+ raise ValueError ("Anchor mode must be one of the AnchorMode enum values." )
236
+
237
+
238
+ def get_action_points (
239
+ action_mode : ActionMode ,
240
+ rgb ,
241
+ point_cloud ,
242
+ mask ,
243
+ action_handles ,
244
+ gripper_handles ,
245
+ ):
246
+ if action_mode == ActionMode .GRIPPER_AND_OBJECT :
247
+ action_handles = action_handles + gripper_handles
248
+ elif action_mode == ActionMode .OBJECT :
249
+ pass
250
+ else :
251
+ raise ValueError ("Action mode must be one of the ActionMode enum values." )
252
+
253
+ action_rgb , action_point_cloud = get_rgb_point_cloud_by_object_handles (
254
+ rgb , point_cloud , mask , action_handles
255
+ )
256
+
257
+ return action_rgb , action_point_cloud
258
+
259
+
194
260
class RLBenchPlacementDataset (data .Dataset ):
195
261
def __init__ (
196
262
self ,
@@ -299,7 +365,7 @@ def _load_keyframes(
299
365
300
366
keyframes = [demo [ix ] for ix in keyframe_ixs ]
301
367
302
- return keyframes , demo [0 ]
368
+ return keyframes , demo [0 ] # type: ignore
303
369
304
370
# We also cache in memory, since all the transformations are the same.
305
371
# Saves a lot of time when loading the dataset, but don't have to worry
@@ -347,69 +413,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
347
413
# Find the first grasp instance
348
414
key_obs = keyframes [phase_ix ]
349
415
350
- if self .debugging :
351
- raise ValueError ("Debugging not implemented." )
352
- return {
353
- "keyframes" : keyframe_ixs ,
354
- "demo" : demo ,
355
- "initial_obs" : initial_obs ,
356
- "key_obs" : key_obs ,
357
- "init_front_rgb" : torch .from_numpy (initial_obs .front_rgb ),
358
- "key_front_rgb" : torch .from_numpy (key_obs .front_rgb ),
359
- "init_front_mask" : torch .from_numpy (
360
- initial_obs .front_mask .astype (np .int32 )
361
- ),
362
- "key_front_mask" : torch .from_numpy (key_obs .front_mask .astype (np .int32 )),
363
- "phase" : phase ,
364
- "phase_onehot" : torch .from_numpy (phase_onehot ),
365
- }
416
+ action_handles = self .names_to_handles [phase ]["action_obj_names" ]
417
+
418
+ def _select_action_vals (rgb , point_cloud , mask ):
419
+ return get_action_points (
420
+ self .action_mode ,
421
+ rgb ,
422
+ point_cloud ,
423
+ mask ,
424
+ action_handles ,
425
+ self .gripper_handles ,
426
+ )
427
+
428
+ def _select_anchor_vals (rgb , point_cloud , mask ):
429
+ return get_anchor_points (
430
+ self .anchor_mode ,
431
+ rgb ,
432
+ point_cloud ,
433
+ mask ,
434
+ self .task_name ,
435
+ phase ,
436
+ use_from_simulator = False ,
437
+ handle_mapping = self .handle_mapping ,
438
+ names_to_handles = self .names_to_handles ,
439
+ )
366
440
367
441
# Merge all the initial point clouds and masks into one.
368
442
init_rgb , init_point_cloud , init_mask = obs_to_rgb_point_cloud (initial_obs )
369
443
370
- action_handles = self .names_to_handles [phase ]["action_obj_names" ]
371
- if self .action_mode == ActionMode .GRIPPER_AND_OBJECT :
372
- action_handles = action_handles + self .gripper_handles
373
- elif self .action_mode == ActionMode .OBJECT :
374
- pass
375
- else :
376
- raise ValueError ("Action mode must be one of the ActionMode enum values." )
377
-
378
- # Split the initial point cloud and rgb into action and anchor.
379
- (
380
- init_action_rgb ,
381
- init_action_point_cloud ,
382
- ) = get_rgb_point_cloud_by_object_handles (
383
- init_rgb , init_point_cloud , init_mask , action_handles
444
+ init_action_rgb , init_action_point_cloud = _select_action_vals (
445
+ init_rgb , init_point_cloud , init_mask
384
446
)
385
447
386
- def _select_anchor_vals (rgb , point_cloud , mask ):
387
- if self .anchor_mode == AnchorMode .RAW :
388
- return rgb , point_cloud
389
- elif self .anchor_mode == AnchorMode .BACKGROUND_REMOVED :
390
- return filter_out_names (
391
- rgb , point_cloud , mask , self .handle_mapping , BACKGROUND_NAMES
392
- )
393
- elif self .anchor_mode == AnchorMode .BACKGROUND_ROBOT_REMOVED :
394
- return filter_out_names (
395
- rgb ,
396
- point_cloud ,
397
- mask ,
398
- self .handle_mapping ,
399
- BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES ,
400
- )
401
- elif self .anchor_mode == AnchorMode .SINGLE_OBJECT :
402
- return get_rgb_point_cloud_by_object_handles (
403
- rgb ,
404
- point_cloud ,
405
- mask ,
406
- self .names_to_handles [phase ]["anchor_obj_names" ],
407
- )
408
- else :
409
- raise ValueError (
410
- "Anchor mode must be one of the AnchorMode enum values."
411
- )
412
-
413
448
init_anchor_rgb , init_anchor_point_cloud = _select_anchor_vals (
414
449
init_rgb , init_point_cloud , init_mask
415
450
)
@@ -418,8 +453,8 @@ def _select_anchor_vals(rgb, point_cloud, mask):
418
453
key_rgb , key_point_cloud , key_mask = obs_to_rgb_point_cloud (key_obs )
419
454
420
455
# Split the key point cloud and rgb into action and anchor.
421
- key_action_rgb , key_action_point_cloud = get_rgb_point_cloud_by_object_handles (
422
- key_rgb , key_point_cloud , key_mask , action_handles
456
+ key_action_rgb , key_action_point_cloud = _select_action_vals (
457
+ key_rgb , key_point_cloud , key_mask
423
458
)
424
459
key_anchor_rgb , key_anchor_point_cloud = _select_anchor_vals (
425
460
key_rgb , key_point_cloud , key_mask
0 commit comments