
Commit d2ad10c

committed: made some changes
1 parent 155a36c commit d2ad10c

4 files changed, +250 -178 lines


notebooks/explore_dset2.ipynb (+40 -7)
@@ -67,14 +67,14 @@
     "import rlbench\n",
     "from rlbench.observation_config import CameraConfig, ObservationConfig\n",
     "\n",
-    "demo = rlbench.utils.get_stored_demos(\n",
-    "    amount=1,\n",
+    "demos = rlbench.utils.get_stored_demos(\n",
+    "    amount=10,\n",
     "    image_paths=False,\n",
-    "    dataset_root=\"/data/rlbench10\",\n",
+    "    dataset_root=\"/data/rlbench10_collisions\",\n",
     "    variation_number=0,\n",
     "    # task_name=\"slide_block_to_target\",\n",
     "    # task_name=\"reach_target\",\n",
-    "    task_name=\"put_money_in_safe\",\n",
+    "    task_name=\"stack_wine\",\n",
     "    obs_config=ObservationConfig(\n",
     "        left_shoulder_camera=CameraConfig(image_size=(256, 256)),\n",
     "        right_shoulder_camera=CameraConfig(image_size=(256, 256)),\n",
@@ -85,7 +85,32 @@
     "    ),\n",
     "    random_selection=False,\n",
     "    from_episode_number=0,\n",
-    ")[0]"
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "for i in range(len(demo)):\n",
+    "    print(demo[i].ignore_collisions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from rpad.rlbench_utils.keyframing_pregrasp import keypoint_discovery_pregrasp\n",
+    "\n",
+    "\n",
+    "keyframe_ixs = keypoint_discovery_pregrasp(demo)\n",
+    "\n",
+    "keyframes = [demo[ix] for ix in keyframe_ixs]"
    ]
   },
   {
@@ -94,7 +119,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dir(demo[0])"
+    "import numpy as np"
    ]
   },
   {
@@ -103,7 +128,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "demo[0].task_low_dim_state.shape"
+    "all_colls = []\n",
+    "for demo in demos:\n",
+    "    keyframe_ixs = keypoint_discovery_pregrasp(demo)\n",
+    "    keyframes = [demo[ix] for ix in keyframe_ixs]\n",
+    "    colls = [keyframe.ignore_collisions for keyframe in keyframes]\n",
+    "    all_colls.append(colls)\n",
+    "\n",
+    "all_colls = np.array(all_colls)\n",
+    "all_colls"
    ]
   },
   {
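
Taken together, the new notebook cells load ten stored stack_wine demos from /data/rlbench10_collisions, run pregrasp keyframe discovery on each, and collect the ignore_collisions flag at every keyframe. Below is a roughly equivalent standalone sketch; it assumes the pre-generated dataset exists at that path, that rpad.rlbench_utils is importable, and it borrows the camera configuration used by _load_keyframes in placement_dataset.py, so treat it as illustrative rather than a copy of the notebook.

# Illustrative sketch only: dataset path, task name, and camera set are taken
# from this commit's diff; the loop mirrors the new notebook cells.
import numpy as np
import rlbench.utils
from rlbench.observation_config import CameraConfig, ObservationConfig
from rpad.rlbench_utils.keyframing_pregrasp import keypoint_discovery_pregrasp

demos = rlbench.utils.get_stored_demos(
    amount=10,
    image_paths=False,
    dataset_root="/data/rlbench10_collisions",
    variation_number=0,
    task_name="stack_wine",
    obs_config=ObservationConfig(
        left_shoulder_camera=CameraConfig(image_size=(256, 256)),
        right_shoulder_camera=CameraConfig(image_size=(256, 256)),
        front_camera=CameraConfig(image_size=(256, 256)),
        wrist_camera=CameraConfig(image_size=(256, 256)),
        overhead_camera=CameraConfig(image_size=(256, 256)),
        task_low_dim_state=True,
    ),
    random_selection=False,
    from_episode_number=0,
)

# For each demo, keep only the discovered keyframes and read the
# ignore_collisions flag RLBench stores on each observation.
all_colls = []
for demo in demos:
    keyframe_ixs = keypoint_discovery_pregrasp(demo)
    keyframes = [demo[ix] for ix in keyframe_ixs]
    all_colls.append([kf.ignore_collisions for kf in keyframes])

# Stacking assumes every demo yields the same number of keyframes.
print(np.array(all_colls))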

src/rpad/rlbench_utils/keyframing.py (+23)
@@ -1,6 +1,8 @@
 """Borrowed the following keyframing code from:
 https://github.com/zhouxian/act3d-chained-diffuser/blob/main/online_evaluation/utils_with_rlbench.py
 """
+
+import logging
 from typing import List
 
 import numpy as np
@@ -51,3 +53,24 @@ def keypoint_discovery(demo: Demo, stopping_delta=0.1) -> List[int]:
         episode_keypoints.pop(-2)
 
     return episode_keypoints
+
+
+def keypoint_discovery_original(demo: Demo, stopping_delta=0.1) -> List[int]:
+    episode_keypoints = []
+    prev_gripper_open = demo[0].gripper_open
+    stopped_buffer = 0
+    for i, obs in enumerate(demo):
+        stopped = _is_stopped(demo, i, obs, stopped_buffer, stopping_delta)
+        stopped_buffer = 4 if stopped else stopped_buffer - 1
+        # If change in gripper, or end of episode.
+        last = i == (len(demo) - 1)
+        if i != 0 and (obs.gripper_open != prev_gripper_open or last or stopped):
+            episode_keypoints.append(i)
+        prev_gripper_open = obs.gripper_open
+    if (
+        len(episode_keypoints) > 1
+        and (episode_keypoints[-1] - 1) == episode_keypoints[-2]
+    ):
+        episode_keypoints.pop(-2)
+    logging.debug("Found %d keypoints." % len(episode_keypoints), episode_keypoints)
+    return episode_keypoints
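
The added keypoint_discovery_original appears to restore the upstream heuristic from the act3d-chained-diffuser code referenced in the module docstring, alongside this repo's own keypoint_discovery. A hypothetical usage sketch for comparing the two on a single demo (here `demos` is assumed to come from rlbench.utils.get_stored_demos, as in the notebook above):

# Hypothetical comparison of the two keyframe heuristics on one stored demo.
from rpad.rlbench_utils.keyframing import (
    keypoint_discovery,
    keypoint_discovery_original,
)

demo = demos[0]  # assumed: a demo loaded earlier via get_stored_demos
ixs_repo = keypoint_discovery(demo, stopping_delta=0.1)
ixs_orig = keypoint_discovery_original(demo, stopping_delta=0.1)

print("keypoint_discovery:         ", ixs_repo)
print("keypoint_discovery_original:", ixs_orig)
print("indices that differ:        ", sorted(set(ixs_repo) ^ set(ixs_orig)))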

src/rpad/rlbench_utils/placement_dataset.py (+39 -81)
@@ -273,6 +273,34 @@ def __len__(self) -> int:
         else:
             return self.n_demos
 
+    @staticmethod
+    def _load_keyframes(
+        dataset_root, variation, task_name, episode_index: int
+    ) -> List[int]:
+        demo = rlbench.utils.get_stored_demos(
+            amount=1,
+            image_paths=False,
+            dataset_root=dataset_root,
+            variation_number=variation,
+            task_name=task_name,
+            obs_config=ObservationConfig(
+                left_shoulder_camera=CameraConfig(image_size=(256, 256)),
+                right_shoulder_camera=CameraConfig(image_size=(256, 256)),
+                front_camera=CameraConfig(image_size=(256, 256)),
+                wrist_camera=CameraConfig(image_size=(256, 256)),
+                overhead_camera=CameraConfig(image_size=(256, 256)),
+                task_low_dim_state=True,
+            ),
+            random_selection=False,
+            from_episode_number=episode_index,
+        )[0]
+
+        keyframe_ixs = keypoint_discovery_pregrasp(demo)
+
+        keyframes = [demo[ix] for ix in keyframe_ixs]
+
+        return keyframes, demo[0]
+
     # We also cache in memory, since all the transformations are the same.
     # Saves a lot of time when loading the dataset, but don't have to worry
     # about logic changes after the fact.
@@ -288,29 +316,15 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
         # demonstrations from disk. But this means that we'll have to be careful
         # whenever we re-generate the demonstrations to delete the cache.
         if self.memory is not None:
-            get_demo_fn = self.memory.cache(rlbench.utils.get_stored_demos)
+            load_keyframes_fn = self.memory.cache(self._load_keyframes)
         else:
-            get_demo_fn = rlbench.utils.get_stored_demos
+            load_keyframes_fn = self._load_keyframes
 
-        demo: rlbench.demo.Demo = get_demo_fn(
-            amount=1,
-            image_paths=False,
-            dataset_root=self.dataset_root,
-            variation_number=self.variation,
-            task_name=self.task_name,
-            obs_config=ObservationConfig(
-                left_shoulder_camera=CameraConfig(image_size=(256, 256)),
-                right_shoulder_camera=CameraConfig(image_size=(256, 256)),
-                front_camera=CameraConfig(image_size=(256, 256)),
-                wrist_camera=CameraConfig(image_size=(256, 256)),
-                overhead_camera=CameraConfig(image_size=(256, 256)),
-                task_low_dim_state=True,
-            ),
-            random_selection=False,
-            from_episode_number=self.demos[index],
-        )[0]
+        keyframes, first_frame = load_keyframes_fn(
+            self.dataset_root, self.variation, self.task_name, self.demos[index]
+        )
 
-        keyframes = keypoint_discovery_pregrasp(demo)
+        # breakpoint()
 
         # Get the index of the phase into keypoints.
         if self.phase == "all":
@@ -326,16 +340,17 @@
 
         # Select an observation to use as the initial observation.
         if self.use_first_as_init_keyframe or phase_ix == 0:
-            initial_obs = demo[0]
+            initial_obs = first_frame
         else:
-            initial_obs = demo[keyframes[phase_ix - 1]]
+            initial_obs = keyframes[phase_ix - 1]
 
         # Find the first grasp instance
-        key_obs = demo[keyframes[phase_ix]]
+        key_obs = keyframes[phase_ix]
 
         if self.debugging:
+            raise ValueError("Debugging not implemented.")
             return {
-                "keyframes": keyframes,
+                "keyframes": keyframe_ixs,
                 "demo": demo,
                 "initial_obs": initial_obs,
                 "key_obs": key_obs,
@@ -395,35 +410,6 @@ def _select_anchor_vals(rgb, point_cloud, mask):
                 "Anchor mode must be one of the AnchorMode enum values."
            )
 
-        # if self.anchor_mode == AnchorMode.RAW:
-        #     init_anchor_rgb = init_rgb
-        #     init_anchor_point_cloud = init_point_cloud
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
-        #     init_anchor_rgb, init_anchor_point_cloud = filter_out_names(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
-        #     init_anchor_rgb, init_anchor_point_cloud = filter_out_names(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
-        #     (
-        #         init_anchor_rgb,
-        #         init_anchor_point_cloud,
-        #     ) = get_rgb_point_cloud_by_object_handles(
-        #         init_rgb,
-        #         init_point_cloud,
-        #         init_mask,
-        #         self.names_to_handles[phase]["anchor_obj_names"],
-        #     )
         init_anchor_rgb, init_anchor_point_cloud = _select_anchor_vals(
             init_rgb, init_point_cloud, init_mask
         )
@@ -435,34 +421,6 @@ def _select_anchor_vals(rgb, point_cloud, mask):
         key_action_rgb, key_action_point_cloud = get_rgb_point_cloud_by_object_handles(
             key_rgb, key_point_cloud, key_mask, action_handles
         )
-        # if self.anchor_mode == AnchorMode.RAW:
-        #     key_anchor_rgb = key_rgb
-        #     key_anchor_point_cloud = key_point_cloud
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
-        #     key_anchor_rgb, key_anchor_point_cloud = filter_out_names(
-        #         key_rgb,
-        #         key_point_cloud,
-        #         key_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
-        #     key_anchor_rgb, key_anchor_point_cloud = filter_out_names(
-        #         key_rgb,
-        #         key_point_cloud,
-        #         key_mask,
-        #         self.handle_mapping,
-        #         BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
-        #     )
-        # elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
-        #     key_anchor_rgb, key_anchor_point_cloud = (
-        #         get_rgb_point_cloud_by_object_handles(
-        #             key_rgb,
-        #             key_point_cloud,
-        #             key_mask,
-        #             self.names_to_handles[phase]["anchor_obj_names"],
-        #         )
-        #     )
         key_anchor_rgb, key_anchor_point_cloud = _select_anchor_vals(
             key_rgb, key_point_cloud, key_mask
        )
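
The __getitem__ refactor moves the expensive demo load behind a single cacheable function: self.memory.cache now wraps _load_keyframes instead of rlbench.utils.get_stored_demos, so only the keyframe observations and the first frame need to be persisted rather than the whole demo. A minimal sketch of that caching pattern, assuming self.memory is a joblib.Memory instance (this diff does not show how it is constructed), with a stand-in function in place of _load_keyframes:

# Minimal sketch of the joblib-style caching pattern; load_keyframes is a
# placeholder for the new _load_keyframes staticmethod, and the cache
# directory is hypothetical.
from typing import List, Tuple

from joblib import Memory

memory = Memory("/tmp/keyframe_cache", verbose=0)


def load_keyframes(
    dataset_root: str, variation: int, task_name: str, episode_index: int
) -> Tuple[List[int], str]:
    # In the real method this calls rlbench.utils.get_stored_demos(...) and
    # keypoint_discovery_pregrasp(demo); placeholders keep the sketch runnable.
    print("expensive load:", task_name, "episode", episode_index)
    return [12, 34, 56], "first_frame_placeholder"


cached_load = memory.cache(load_keyframes)
cached_load("/data/rlbench10_collisions", 0, "stack_wine", 0)  # computes and writes cache
cached_load("/data/rlbench10_collisions", 0, "stack_wine", 0)  # replayed from disk

Because joblib keys the cache on the wrapped function and its arguments, the cache has to be cleared whenever the demonstrations on disk are regenerated, as the original comments in __getitem__ already warn.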
