Skip to content

Commit 155a36c

Browse files
committed
additiona deteails
1 parent 43338a9 commit 155a36c

File tree

3 files changed

+404
-44
lines changed

3 files changed

+404
-44
lines changed

notebooks/explore_dset2.ipynb

+238-15
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
"\n",
2222
"from rpad.visualize_3d.plots import segmentation_fig\n",
2323
"import torch\n",
24-
"import matplotlib.pyplot as plt"
24+
"import matplotlib.pyplot as plt\n",
25+
"from plotly.subplots import make_subplots\n"
2526
]
2627
},
2728
{
@@ -39,12 +40,81 @@
3940
" # task_name=\"put_toilet_roll_on_stand\",\n",
4041
" # task_name=\"place_hanger_on_rack\",\n",
4142
" # task_name=\"solve_puzzle\",\n",
42-
" task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
43+
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
44+
" task_name=\"slide_block_to_target\",\n",
4345
" demos=range(10),\n",
4446
" phase=\"all\",\n",
4547
")"
4648
]
4749
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"low_dim_state_dict = load_state_pos_dict(\n",
57+
" \"/data/rlbench10\", \"slide_block_to_target\", 0, 0\n",
58+
")"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"import rlbench\n",
68+
"from rlbench.observation_config import CameraConfig, ObservationConfig\n",
69+
"\n",
70+
"demo = rlbench.utils.get_stored_demos(\n",
71+
" amount=1,\n",
72+
" image_paths=False,\n",
73+
" dataset_root=\"/data/rlbench10\",\n",
74+
" variation_number=0,\n",
75+
" # task_name=\"slide_block_to_target\",\n",
76+
" # task_name=\"reach_target\",\n",
77+
" task_name=\"put_money_in_safe\",\n",
78+
" obs_config=ObservationConfig(\n",
79+
" left_shoulder_camera=CameraConfig(image_size=(256, 256)),\n",
80+
" right_shoulder_camera=CameraConfig(image_size=(256, 256)),\n",
81+
" front_camera=CameraConfig(image_size=(256, 256)),\n",
82+
" wrist_camera=CameraConfig(image_size=(256, 256)),\n",
83+
" overhead_camera=CameraConfig(image_size=(256, 256)),\n",
84+
" task_low_dim_state=True,\n",
85+
" ),\n",
86+
" random_selection=False,\n",
87+
" from_episode_number=0,\n",
88+
")[0]"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"dir(demo[0])"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"demo[0].task_low_dim_state.shape"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"metadata": {},
113+
"outputs": [],
114+
"source": [
115+
"print(list(low_dim_state_dict.items()))"
116+
]
117+
},
48118
{
49119
"cell_type": "markdown",
50120
"metadata": {},
@@ -67,7 +137,8 @@
67137
" # task_name=\"put_toilet_roll_on_stand\",\n",
68138
" # task_name=\"place_hanger_on_rack\",\n",
69139
" # task_name=\"solve_puzzle\",\n",
70-
" task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
140+
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
141+
" task_name=\"slide_block_to_target\",\n",
71142
" demos=[0],\n",
72143
" phase=\"all\",\n",
73144
" use_first_as_init_keyframe=False,\n",
@@ -121,6 +192,27 @@
121192
" axes[i // 5, i % 5].set_title(f\"Demo {i}\")"
122193
]
123194
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": null,
198+
"metadata": {},
199+
"outputs": [],
200+
"source": [
201+
"RLBENCH10_TASKS = [\n",
202+
" \"pick_and_lift\",\n",
203+
" \"put_knife_on_chopping_board\",\n",
204+
" \"take_money_out_safe\",\n",
205+
" \"pick_up_cup\",\n",
206+
" \"put_money_in_safe\",\n",
207+
" \"slide_block_to_target\",\n",
208+
" \"take_umbrella_out_of_umbrella_stand\",\n",
209+
" \"push_button\",\n",
210+
" \"reach_target\",\n",
211+
" \"stack_wine\",\n",
212+
" \n",
213+
"]"
214+
]
215+
},
124216
{
125217
"cell_type": "markdown",
126218
"metadata": {},
@@ -271,6 +363,7 @@
271363
"metadata": {},
272364
"outputs": [],
273365
"source": [
366+
"task_name = \"put_money_in_safe\"\n",
274367
"state_pos_dict = load_state_pos_dict(\"/data/rlbench10/\", task_name, 0, 0)\n",
275368
"state_pos_dict"
276369
]
@@ -303,13 +396,14 @@
303396
" fig: Optional[go.Figure] = None,\n",
304397
" row: int = 1,\n",
305398
" column: int = 1,\n",
399+
" n_col: int = 5,\n",
306400
"):\n",
307401
" \"\"\"Creates a segmentation figure.\"\"\"\n",
308402
" # Create a figure.\n",
309403
" if fig is None:\n",
310404
" fig = go.Figure()\n",
311405
"\n",
312-
" scene_num = (row-1) * 5 + column\n",
406+
" scene_num = (row-1) * n_col + column\n",
313407
"\n",
314408
" fig.add_traces(_segmentation_traces(data, labels, labelmap, f\"scene{scene_num}\", sizes), rows=row, cols=column)\n",
315409
"\n",
@@ -333,6 +427,15 @@
333427
"data.keys()"
334428
]
335429
},
430+
{
431+
"cell_type": "code",
432+
"execution_count": null,
433+
"metadata": {},
434+
"outputs": [],
435+
"source": [
436+
"from plotly.subplots import make_subplots\n"
437+
]
438+
},
336439
{
337440
"cell_type": "code",
338441
"execution_count": null,
@@ -341,7 +444,6 @@
341444
"source": [
342445
"# For each phase, plot the segmentation.\n",
343446
"\n",
344-
"from plotly.subplots import make_subplots\n",
345447
"\n",
346448
"phase = list(TASK_DICT[\"take_umbrella_out_of_umbrella_stand\"][\"phase\"].keys())[0]\n",
347449
"\n",
@@ -386,29 +488,52 @@
386488
"execution_count": null,
387489
"metadata": {},
388490
"outputs": [],
389-
"source": []
491+
"source": [
492+
"demo[0].task_low_dim_state"
493+
]
390494
},
391495
{
392496
"cell_type": "code",
393497
"execution_count": null,
394498
"metadata": {},
395499
"outputs": [],
396500
"source": [
397-
"# task_name = \"stack_wine\"\n",
398-
"# task_name = \"reach_target\"\n",
501+
"# - \"pick_and_lift\",\n",
502+
"# - \"pick_up_cup\",\n",
503+
"# - \"put_knife_on_chopping_board\",\n",
504+
"# - \"put_money_in_safe\",\n",
505+
"# - \"push_button\",\n",
506+
"# - \"reach_target\",\n",
507+
"# - \"slide_block_to_target\",\n",
508+
"# - \"stack_wine\",\n",
509+
"# - \"take_money_out_safe\",\n",
510+
"# - \"take_umbrella_out_of_umbrella_stand\",\n",
511+
"\n",
512+
"# task_name = \"pick_and_lift\"\n",
513+
"# task_name = \"pick_up_cup\"\n",
514+
"# task_name = \"put_knife_on_chopping_board\"\n",
399515
"# task_name = \"put_money_in_safe\"\n",
516+
"# task_name = \"push_button\"\n",
517+
"# task_name = \"reach_target\"\n",
518+
"# task_name = \"slide_block_to_target\"\n",
519+
"task_name = \"stack_wine\"\n",
400520
"# task_name = \"take_money_out_safe\"\n",
401-
"task_name = \"put_knife_on_chopping_board\"\n",
521+
"# task_name = \"take_umbrella_out_of_umbrella_stand\"\n",
522+
"\n",
523+
"n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n",
524+
"fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n",
402525
"\n",
403-
"for phase in TASK_DICT[task_name][\"phase_order\"]:\n",
526+
"for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
404527
" print(f\"Phase: {phase}\")\n",
405528
" dset = RLBenchPlacementDataset(\n",
406529
" dataset_root=\"/data/rlbench10/\",\n",
407530
" task_name=task_name,\n",
408-
" n_demos=1,\n",
531+
" demos=[0],\n",
409532
" phase=phase,\n",
410533
" debugging=False,\n",
411534
" use_first_as_init_keyframe=False,\n",
535+
" anchor_mode=\"background_robot_removed\",\n",
536+
" action_mode=\"gripper_and_object\",\n",
412537
" )\n",
413538
"\n",
414539
" data = dset[0]\n",
@@ -417,28 +542,39 @@
417542
"\n",
418543
" print(list(data.keys()))\n",
419544
"\n",
545+
" anchor_pc = data[\"init_anchor_pc\"]\n",
546+
" # Randomly downsample the anchor point cloud.\n",
547+
" n_pts = anchor_pc.shape[0]\n",
548+
" if n_pts > 1000:\n",
549+
" anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n",
550+
"\n",
420551
" points = torch.cat(\n",
421552
" [\n",
422553
" data[\"init_action_pc\"],\n",
423-
" data[\"init_anchor_pc\"],\n",
554+
" anchor_pc,\n",
424555
" data[\"key_action_pc\"],\n",
425556
" ]\n",
426557
" )\n",
427558
" print(points.shape)\n",
428559
" seg = torch.cat(\n",
429560
" [\n",
430561
" torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
431-
" torch.ones(data[\"init_anchor_pc\"].shape[0]),\n",
562+
" torch.ones(anchor_pc.shape[0]),\n",
432563
" 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
433564
" ]\n",
434565
" )\n",
435-
" fig = segmentation_fig(\n",
566+
"\n",
567+
" fig = segmentation_fig_rc(\n",
436568
" points,\n",
437569
" seg.int(),\n",
438570
" labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
571+
" fig=fig,\n",
572+
" row=1,\n",
573+
" column=ix+1,\n",
574+
" n_col=n_phases,\n",
439575
" )\n",
440-
" fig.show()\n",
441576
"\n",
577+
"fig.show()\n",
442578
" "
443579
]
444580
},
@@ -451,6 +587,93 @@
451587
"!ls /data/rlbench10/put_money_in_safe/variation0/episodes/episode0"
452588
]
453589
},
590+
{
591+
"cell_type": "code",
592+
"execution_count": null,
593+
"metadata": {},
594+
"outputs": [],
595+
"source": [
596+
"# Debugging dataset.\n",
597+
"\n"
598+
]
599+
},
600+
{
601+
"cell_type": "code",
602+
"execution_count": null,
603+
"metadata": {},
604+
"outputs": [],
605+
"source": []
606+
},
607+
{
608+
"cell_type": "code",
609+
"execution_count": null,
610+
"metadata": {},
611+
"outputs": [],
612+
"source": [
613+
"np.unique(data[\"init_front_mask\"])"
614+
]
615+
},
616+
{
617+
"cell_type": "code",
618+
"execution_count": null,
619+
"metadata": {},
620+
"outputs": [],
621+
"source": [
622+
"\n",
623+
"\n",
624+
"from rpad.rlbench_utils.placement_dataset import obs_to_rgb_point_cloud\n",
625+
"\n",
626+
"\n"
627+
]
628+
},
629+
{
630+
"cell_type": "code",
631+
"execution_count": null,
632+
"metadata": {},
633+
"outputs": [],
634+
"source": []
635+
},
636+
{
637+
"cell_type": "code",
638+
"execution_count": null,
639+
"metadata": {},
640+
"outputs": [],
641+
"source": [
642+
"\n",
643+
"unique_elements = set()\n",
644+
"\n",
645+
"\n",
646+
"for task in RLBENCH10_TASKS:\n",
647+
" print(f\"Task: {task}\")\n",
648+
"\n",
649+
" dset = RLBenchPlacementDataset(\n",
650+
" dataset_root=\"/data/rlbench10/\",\n",
651+
" task_name=task,\n",
652+
" demos=[0],\n",
653+
" phase=\"all\",\n",
654+
" debugging=True,\n",
655+
" )\n",
656+
" data = dset[0]\n",
657+
" init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(data[\"initial_obs\"])\n",
658+
" handle_mapping= load_handle_mapping(\n",
659+
" dset.dataset_root, dset.task_name, dset.variation\n",
660+
" )\n",
661+
" inv_h_map = {v: k for k, v in handle_mapping.items()}\n",
662+
"\n",
663+
" for id in np.unique(init_mask):\n",
664+
" print(inv_h_map[id])\n",
665+
" unique_elements.add(inv_h_map[id])"
666+
]
667+
},
668+
{
669+
"cell_type": "code",
670+
"execution_count": null,
671+
"metadata": {},
672+
"outputs": [],
673+
"source": [
674+
"unique_elements"
675+
]
676+
},
454677
{
455678
"cell_type": "code",
456679
"execution_count": null,

0 commit comments

Comments
 (0)