Skip to content

Commit 84f84f0

Browse files
committed
add in background too
1 parent 11722e4 commit 84f84f0

File tree

2 files changed

+98
-56
lines changed

2 files changed

+98
-56
lines changed

notebooks/explore_dset2.ipynb

+58-55
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@
374374
"outputs": [],
375375
"source": [
376376
"# Get a mapping from handle id to handle name.\n",
377-
"task_name = \"put_money_in_safe\"\n",
377+
"# task_name = \"put_money_in_safe\"\n",
378+
"task_name = \"put_knife_on_chopping_board\"\n",
378379
"handle_mapping = load_handle_mapping(\"/data/rlbench10_collisions/\", task_name, 0)\n",
379380
"rev_handle_mapping = {v: k for k, v in handle_mapping.items()}\n",
380381
"\n",
@@ -388,7 +389,7 @@
388389
"metadata": {},
389390
"outputs": [],
390391
"source": [
391-
"handle_mapping"
392+
"set(handle_mapping.keys())"
392393
]
393394
},
394395
{
@@ -563,63 +564,65 @@
563564
"# task_name = \"take_money_out_safe\"\n",
564565
"# task_name = \"take_umbrella_out_of_umbrella_stand\"\n",
565566
"\n",
566-
"n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n",
567-
"fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n",
567+
"for i in range(4):\n",
568568
"\n",
569-
"for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
570-
" print(f\"Phase: {phase}\")\n",
571-
" dset = RLBenchPlacementDataset(\n",
572-
" dataset_root=\"/data/rlbench10_collisions/\",\n",
573-
" task_name=task_name,\n",
574-
" demos=[0],\n",
575-
" phase=phase,\n",
576-
" debugging=False,\n",
577-
" use_first_as_init_keyframe=False,\n",
578-
" anchor_mode=\"background_robot_removed\",\n",
579-
" action_mode=\"gripper_and_object\",\n",
580-
" include_wrist_cam=True,\n",
581-
" gripper_in_first_phase=True,\n",
582-
" )\n",
569+
" n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n",
570+
" fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n",
583571
"\n",
584-
" data = dset[0]\n",
585-
"\n",
586-
" # Plot segmentation with segmentation_fig\n",
587-
"\n",
588-
" print(list(data.keys()))\n",
589-
"\n",
590-
" anchor_pc = data[\"init_anchor_pc\"]\n",
591-
" # Randomly downsample the anchor point cloud.\n",
592-
" n_pts = anchor_pc.shape[0]\n",
593-
" if n_pts > 1000:\n",
594-
" anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n",
595-
"\n",
596-
" points = torch.cat(\n",
597-
" [\n",
598-
" data[\"init_action_pc\"],\n",
599-
" anchor_pc,\n",
600-
" data[\"key_action_pc\"],\n",
601-
" ]\n",
602-
" )\n",
603-
" print(points.shape)\n",
604-
" seg = torch.cat(\n",
605-
" [\n",
606-
" torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
607-
" torch.ones(anchor_pc.shape[0]),\n",
608-
" 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
609-
" ]\n",
610-
" )\n",
572+
" for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
573+
" print(f\"Phase: {phase}\")\n",
574+
" dset = RLBenchPlacementDataset(\n",
575+
" dataset_root=\"/data/rlbench10_collisions/\",\n",
576+
" task_name=task_name,\n",
577+
" demos=range(100),\n",
578+
" phase=phase,\n",
579+
" debugging=False,\n",
580+
" use_first_as_init_keyframe=False,\n",
581+
" anchor_mode=\"background_robot_removed\",\n",
582+
" action_mode=\"gripper_and_object\",\n",
583+
" include_wrist_cam=True,\n",
584+
" gripper_in_first_phase=True,\n",
585+
" )\n",
611586
"\n",
612-
" fig = segmentation_fig_rc(\n",
613-
" points,\n",
614-
" seg.int(),\n",
615-
" labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
616-
" fig=fig,\n",
617-
" row=1,\n",
618-
" column=ix+1,\n",
619-
" n_col=n_phases,\n",
620-
" )\n",
587+
" data = dset[i]\n",
621588
"\n",
622-
"fig.show()\n",
589+
" # Plot segmentation with segmentation_fig\n",
590+
"\n",
591+
" print(list(data.keys()))\n",
592+
"\n",
593+
" anchor_pc = data[\"init_anchor_pc\"]\n",
594+
" # Randomly downsample the anchor point cloud.\n",
595+
" n_pts = anchor_pc.shape[0]\n",
596+
" if n_pts > 1000:\n",
597+
" anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n",
598+
"\n",
599+
" points = torch.cat(\n",
600+
" [\n",
601+
" data[\"init_action_pc\"],\n",
602+
" anchor_pc,\n",
603+
" data[\"key_action_pc\"],\n",
604+
" ]\n",
605+
" )\n",
606+
" print(points.shape)\n",
607+
" seg = torch.cat(\n",
608+
" [\n",
609+
" torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
610+
" torch.ones(anchor_pc.shape[0]),\n",
611+
" 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
612+
" ]\n",
613+
" )\n",
614+
"\n",
615+
" fig = segmentation_fig_rc(\n",
616+
" points,\n",
617+
" seg.int(),\n",
618+
" labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
619+
" fig=fig,\n",
620+
" row=1,\n",
621+
" column=ix+1,\n",
622+
" n_col=n_phases,\n",
623+
" )\n",
624+
"\n",
625+
" fig.show()\n",
623626
" "
624627
]
625628
},

src/rpad/rlbench_utils/placement_dataset.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,51 @@ def load_state_pos_dict(
140140

141141

142142
BACKGROUND_NAMES = [
143+
"DefaultCamera",
144+
"DefaultLightA",
145+
"DefaultLightB",
146+
"DefaultLightC",
147+
"DefaultLightD",
148+
"DefaultLights",
149+
"DefaultNXViewCamera",
150+
"DefaultNYViewCamera",
151+
"DefaultNZViewCamera",
152+
"DefaultXViewCamera",
153+
"DefaultYViewCamera",
154+
"DefaultZViewCamera",
155+
"Dummy",
156+
"Floor",
157+
"FloorAnchor",
158+
"ResizableFloor_5_25",
159+
"ResizableFloor_5_25_element",
143160
"ResizableFloor_5_25_visibleElement",
161+
"Roof",
144162
"Wall1",
145163
"Wall2",
146164
"Wall3",
147165
"Wall4",
148-
"Roof",
166+
"XYZCameraProxy",
167+
"boundary",
168+
"cam_cinematic_base",
169+
"cam_cinematic_placeholder",
170+
"cam_front",
171+
"cam_front_mask",
172+
"cam_over_shoulder_left",
173+
"cam_over_shoulder_left_mask",
174+
"cam_over_shoulder_right",
175+
"cam_over_shoulder_right_mask",
176+
"cam_overhead",
177+
"cam_overhead_mask",
178+
"cam_wrist",
179+
"cam_wrist_mask",
180+
"diningTable",
149181
"diningTable_visible",
182+
"remoteApi",
183+
"success",
184+
"waypoint0",
185+
"waypoint1",
186+
"waypoint2",
187+
"waypoint3",
150188
"workspace",
151189
]
152190

@@ -165,6 +203,7 @@ def load_state_pos_dict(
165203
def filter_out_names(rgb, point_cloud, mask, handlemapping, names=BACKGROUND_NAMES):
166204
# Get the indices of the background.
167205
background_handles = [handlemapping[name] for name in names]
206+
background_handles.append(65535) # It's -1, cast as uint16.
168207
background_indices = np.isin(mask, background_handles).reshape((-1))
169208

170209
# Get the indices of the foreground.

0 commit comments

Comments
 (0)