@@ -273,6 +273,34 @@ def __len__(self) -> int:
         else:
             return self.n_demos

+    @staticmethod
+    def _load_keyframes(
+        dataset_root, variation, task_name, episode_index: int
+    ) -> Tuple[List[Observation], Observation]:
+        demo = rlbench.utils.get_stored_demos(
+            amount=1,
+            image_paths=False,
+            dataset_root=dataset_root,
+            variation_number=variation,
+            task_name=task_name,
+            obs_config=ObservationConfig(
+                left_shoulder_camera=CameraConfig(image_size=(256, 256)),
+                right_shoulder_camera=CameraConfig(image_size=(256, 256)),
+                front_camera=CameraConfig(image_size=(256, 256)),
+                wrist_camera=CameraConfig(image_size=(256, 256)),
+                overhead_camera=CameraConfig(image_size=(256, 256)),
+                task_low_dim_state=True,
+            ),
+            random_selection=False,
+            from_episode_number=episode_index,
+        )[0]
+
+        keyframe_ixs = keypoint_discovery_pregrasp(demo)
+
+        keyframes = [demo[ix] for ix in keyframe_ixs]
+
+        return keyframes, demo[0]
+
     # We also cache in memory, since all the transformations are the same.
     # Saves a lot of time when loading the dataset, but don't have to worry
     # about logic changes after the fact.
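The new `_load_keyframes` helper folds the disk load (`rlbench.utils.get_stored_demos`) and keyframe discovery (`keypoint_discovery_pregrasp`) into one static function, returning the keyframe observations together with the episode's first frame. A minimal sketch of calling it directly, assuming the enclosing class is named `RLBenchDataset` and that the path and task name below are placeholders:

```python
# Hypothetical standalone call; "RLBenchDataset", the dataset path, and the
# task name are assumptions for illustration, not taken from this diff.
keyframes, first_frame = RLBenchDataset._load_keyframes(
    dataset_root="/data/rlbench",
    variation=0,
    task_name="stack_wine",
    episode_index=0,
)
# keyframes: the Observation at each discovered keypoint
# first_frame: the first Observation of the episode (demo[0])
```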
@@ -288,29 +316,15 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
         # demonstrations from disk. But this means that we'll have to be careful
         # whenever we re-generate the demonstrations to delete the cache.
         if self.memory is not None:
-            get_demo_fn = self.memory.cache(rlbench.utils.get_stored_demos)
+            load_keyframes_fn = self.memory.cache(self._load_keyframes)
         else:
-            get_demo_fn = rlbench.utils.get_stored_demos
+            load_keyframes_fn = self._load_keyframes

-        demo: rlbench.demo.Demo = get_demo_fn(
-            amount=1,
-            image_paths=False,
-            dataset_root=self.dataset_root,
-            variation_number=self.variation,
-            task_name=self.task_name,
-            obs_config=ObservationConfig(
-                left_shoulder_camera=CameraConfig(image_size=(256, 256)),
-                right_shoulder_camera=CameraConfig(image_size=(256, 256)),
-                front_camera=CameraConfig(image_size=(256, 256)),
-                wrist_camera=CameraConfig(image_size=(256, 256)),
-                overhead_camera=CameraConfig(image_size=(256, 256)),
-                task_low_dim_state=True,
-            ),
-            random_selection=False,
-            from_episode_number=self.demos[index],
-        )[0]
+        keyframes, first_frame = load_keyframes_fn(
+            self.dataset_root, self.variation, self.task_name, self.demos[index]
+        )

-        keyframes = keypoint_discovery_pregrasp(demo)
+        # breakpoint()

         # Get the index of the phase into keypoints.
         if self.phase == "all":
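The cache now wraps the static helper rather than `rlbench.utils.get_stored_demos`, so what gets memoized is the already-extracted keyframes, keyed only by `(dataset_root, variation, task_name, episode_index)`. A rough sketch of the same pattern with `joblib.Memory`; the cache directory and the dummy loader body are assumptions:

```python
from joblib import Memory

memory = Memory("./rlbench_cache", verbose=0)  # assumed cache location

def load_keyframes(dataset_root, variation, task_name, episode_index):
    # Stand-in for the real work (load the stored demo, run
    # keypoint_discovery_pregrasp, slice out the keyframe observations).
    return [f"obs_{i}" for i in range(3)], "first_obs"

# Because _load_keyframes is a @staticmethod, self never appears in the
# argument list, so joblib's cache key depends only on the four arguments.
cached = memory.cache(load_keyframes)
keyframes, first_frame = cached("/data/rlbench", 0, "stack_wine", 0)
```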
@@ -326,16 +340,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:

         # Select an observation to use as the initial observation.
         if self.use_first_as_init_keyframe or phase_ix == 0:
-            initial_obs = demo[0]
+            initial_obs = first_frame
         else:
-            initial_obs = demo[keyframes[phase_ix - 1]]
+            initial_obs = keyframes[phase_ix - 1]

         # Find the first grasp instance
-        key_obs = demo[keyframes[phase_ix]]
+        key_obs = keyframes[phase_ix]

         if self.debugging:
+            raise ValueError("Debugging not implemented.")
             return {
-                "keyframes": keyframes,
+                "keyframes": keyframe_ixs,
                 "demo": demo,
                 "initial_obs": initial_obs,
                 "key_obs": key_obs,
@@ -395,35 +410,6 @@ def _select_anchor_vals(rgb, point_cloud, mask):
                     "Anchor mode must be one of the AnchorMode enum values."
                 )

-        # if self.anchor_mode == AnchorMode.RAW:
-        #     init_anchor_rgb = init_rgb
-        #     init_anchor_point_cloud = init_point_cloud
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
-        #     init_anchor_rgb, init_anchor_point_cloud = filter_out_names(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
-        #     init_anchor_rgb, init_anchor_point_cloud = filter_out_names(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
-        #     (
-        #         init_anchor_rgb,
-        #         init_anchor_point_cloud,
-        #     ) = get_rgb_point_cloud_by_object_handles(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.names_to_handles[phase]["anchor_obj_names"],
-        #     )
         init_anchor_rgb, init_anchor_point_cloud = _select_anchor_vals(
             init_rgb, init_point_cloud, init_mask
         )
@@ -435,34 +421,6 @@ def _select_anchor_vals(rgb, point_cloud, mask):
         key_action_rgb, key_action_point_cloud = get_rgb_point_cloud_by_object_handles(
             key_rgb, key_point_cloud, key_mask, action_handles
         )
-        # if self.anchor_mode == AnchorMode.RAW:
-        #     key_anchor_rgb = key_rgb
-        #     key_anchor_point_cloud = key_point_cloud
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
-        #     key_anchor_rgb, key_anchor_point_cloud = filter_out_names(
-        #         key_rgb,
-        #         key_point_cloud,
-        #         key_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
-        #     key_anchor_rgb, key_anchor_point_cloud = filter_out_names(
-        #         key_rgb,
-        #         key_point_cloud,
-        #         key_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
-        #     key_anchor_rgb, key_anchor_point_cloud = (
-        #         get_rgb_point_cloud_by_object_handles(
-        #             key_rgb,
-        #             key_point_cloud,
-        #             key_mask,
-        #             self.names_to_handles[phase]["anchor_obj_names"],
-        #         )
-        #     )
         key_anchor_rgb, key_anchor_point_cloud = _select_anchor_vals(
             key_rgb, key_point_cloud, key_mask
         )
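Both commented-out anchor-selection blocks are deleted in favor of the shared `_select_anchor_vals` closure. Reconstructed from the removed branches, the closure presumably dispatches on `self.anchor_mode` roughly as below; this is a sketch that reuses names from the deleted code (`filter_out_names`, `BACKGROUND_NAMES`, `ROBOT_NONGRIPPER_NAMES`, `get_rgb_point_cloud_by_object_handles`) and relies on `self` and `phase` from the surrounding `__getitem__`, not the exact implementation in the repo:

```python
def _select_anchor_vals(rgb, point_cloud, mask):
    # Sketch reconstructed from the deleted commented-out branches.
    if self.anchor_mode == AnchorMode.RAW:
        return rgb, point_cloud
    elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
        return filter_out_names(
            rgb, point_cloud, mask, self.handle_mapping, BACKGROUND_NAMES
        )
    elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
        return filter_out_names(
            rgb, point_cloud, mask, self.handle_mapping,
            BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
        )
    elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
        return get_rgb_point_cloud_by_object_handles(
            rgb, point_cloud, mask,
            self.names_to_handles[phase]["anchor_obj_names"],
        )
    else:
        raise ValueError(
            "Anchor mode must be one of the AnchorMode enum values."
        )
```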