|
21 | 21 | "\n",
|
22 | 22 | "from rpad.visualize_3d.plots import segmentation_fig\n",
|
23 | 23 | "import torch\n",
|
24 |
| - "import matplotlib.pyplot as plt" |
| 24 | + "import matplotlib.pyplot as plt\n", |
| 25 | + "from plotly.subplots import make_subplots\n" |
25 | 26 | ]
|
26 | 27 | },
|
27 | 28 | {
|
|
39 | 40 | " # task_name=\"put_toilet_roll_on_stand\",\n",
|
40 | 41 | " # task_name=\"place_hanger_on_rack\",\n",
|
41 | 42 | " # task_name=\"solve_puzzle\",\n",
|
42 |
| - " task_name=\"take_umbrella_out_of_umbrella_stand\",\n", |
| 43 | + " # task_name=\"take_umbrella_out_of_umbrella_stand\",\n", |
| 44 | + " task_name=\"slide_block_to_target\",\n", |
43 | 45 | " demos=range(10),\n",
|
44 | 46 | " phase=\"all\",\n",
|
45 | 47 | ")"
|
46 | 48 | ]
|
47 | 49 | },
|
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": null, |
| 53 | + "metadata": {}, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | + "low_dim_state_dict = load_state_pos_dict(\n", |
| 57 | + " \"/data/rlbench10\", \"slide_block_to_target\", 0, 0\n", |
| 58 | + ")" |
| 59 | + ] |
| 60 | + }, |
| 61 | + { |
| 62 | + "cell_type": "code", |
| 63 | + "execution_count": null, |
| 64 | + "metadata": {}, |
| 65 | + "outputs": [], |
| 66 | + "source": [ |
| 67 | + "import rlbench\n", |
| 68 | + "from rlbench.observation_config import CameraConfig, ObservationConfig\n", |
| 69 | + "\n", |
| 70 | + "demo = rlbench.utils.get_stored_demos(\n", |
| 71 | + " amount=1,\n", |
| 72 | + " image_paths=False,\n", |
| 73 | + " dataset_root=\"/data/rlbench10\",\n", |
| 74 | + " variation_number=0,\n", |
| 75 | + " # task_name=\"slide_block_to_target\",\n", |
| 76 | + " # task_name=\"reach_target\",\n", |
| 77 | + " task_name=\"put_money_in_safe\",\n", |
| 78 | + " obs_config=ObservationConfig(\n", |
| 79 | + " left_shoulder_camera=CameraConfig(image_size=(256, 256)),\n", |
| 80 | + " right_shoulder_camera=CameraConfig(image_size=(256, 256)),\n", |
| 81 | + " front_camera=CameraConfig(image_size=(256, 256)),\n", |
| 82 | + " wrist_camera=CameraConfig(image_size=(256, 256)),\n", |
| 83 | + " overhead_camera=CameraConfig(image_size=(256, 256)),\n", |
| 84 | + " task_low_dim_state=True,\n", |
| 85 | + " ),\n", |
| 86 | + " random_selection=False,\n", |
| 87 | + " from_episode_number=0,\n", |
| 88 | + ")[0]" |
| 89 | + ] |
| 90 | + }, |
| 91 | + { |
| 92 | + "cell_type": "code", |
| 93 | + "execution_count": null, |
| 94 | + "metadata": {}, |
| 95 | + "outputs": [], |
| 96 | + "source": [ |
| 97 | + "dir(demo[0])" |
| 98 | + ] |
| 99 | + }, |
| 100 | + { |
| 101 | + "cell_type": "code", |
| 102 | + "execution_count": null, |
| 103 | + "metadata": {}, |
| 104 | + "outputs": [], |
| 105 | + "source": [ |
| 106 | + "demo[0].task_low_dim_state.shape" |
| 107 | + ] |
| 108 | + }, |
| 109 | + { |
| 110 | + "cell_type": "code", |
| 111 | + "execution_count": null, |
| 112 | + "metadata": {}, |
| 113 | + "outputs": [], |
| 114 | + "source": [ |
| 115 | + "print(list(low_dim_state_dict.items()))" |
| 116 | + ] |
| 117 | + }, |
48 | 118 | {
|
49 | 119 | "cell_type": "markdown",
|
50 | 120 | "metadata": {},
|
|
67 | 137 | " # task_name=\"put_toilet_roll_on_stand\",\n",
|
68 | 138 | " # task_name=\"place_hanger_on_rack\",\n",
|
69 | 139 | " # task_name=\"solve_puzzle\",\n",
|
70 |
| - " task_name=\"take_umbrella_out_of_umbrella_stand\",\n", |
| 140 | + " # task_name=\"take_umbrella_out_of_umbrella_stand\",\n", |
| 141 | + " task_name=\"slide_block_to_target\",\n", |
71 | 142 | " demos=[0],\n",
|
72 | 143 | " phase=\"all\",\n",
|
73 | 144 | " use_first_as_init_keyframe=False,\n",
|
|
121 | 192 | " axes[i // 5, i % 5].set_title(f\"Demo {i}\")"
|
122 | 193 | ]
|
123 | 194 | },
|
| 195 | + { |
| 196 | + "cell_type": "code", |
| 197 | + "execution_count": null, |
| 198 | + "metadata": {}, |
| 199 | + "outputs": [], |
| 200 | + "source": [ |
| 201 | + "RLBENCH10_TASKS = [\n", |
| 202 | + " \"pick_and_lift\",\n", |
| 203 | + " \"put_knife_on_chopping_board\",\n", |
| 204 | + " \"take_money_out_safe\",\n", |
| 205 | + " \"pick_up_cup\",\n", |
| 206 | + " \"put_money_in_safe\",\n", |
| 207 | + " \"slide_block_to_target\",\n", |
| 208 | + " \"take_umbrella_out_of_umbrella_stand\",\n", |
| 209 | + " \"push_button\",\n", |
| 210 | + " \"reach_target\",\n", |
| 211 | + " \"stack_wine\",\n", |
| 212 | + " \n", |
| 213 | + "]" |
| 214 | + ] |
| 215 | + }, |
124 | 216 | {
|
125 | 217 | "cell_type": "markdown",
|
126 | 218 | "metadata": {},
|
|
271 | 363 | "metadata": {},
|
272 | 364 | "outputs": [],
|
273 | 365 | "source": [
|
| 366 | + "task_name = \"put_money_in_safe\"\n", |
274 | 367 | "state_pos_dict = load_state_pos_dict(\"/data/rlbench10/\", task_name, 0, 0)\n",
|
275 | 368 | "state_pos_dict"
|
276 | 369 | ]
|
|
303 | 396 | " fig: Optional[go.Figure] = None,\n",
|
304 | 397 | " row: int = 1,\n",
|
305 | 398 | " column: int = 1,\n",
|
| 399 | + " n_col: int = 5,\n", |
306 | 400 | "):\n",
|
307 | 401 | " \"\"\"Creates a segmentation figure.\"\"\"\n",
|
308 | 402 | " # Create a figure.\n",
|
309 | 403 | " if fig is None:\n",
|
310 | 404 | " fig = go.Figure()\n",
|
311 | 405 | "\n",
|
312 |
| - " scene_num = (row-1) * 5 + column\n", |
| 406 | + " scene_num = (row-1) * n_col + column\n", |
313 | 407 | "\n",
|
314 | 408 | " fig.add_traces(_segmentation_traces(data, labels, labelmap, f\"scene{scene_num}\", sizes), rows=row, cols=column)\n",
|
315 | 409 | "\n",
|
|
333 | 427 | "data.keys()"
|
334 | 428 | ]
|
335 | 429 | },
|
| 430 | + { |
| 431 | + "cell_type": "code", |
| 432 | + "execution_count": null, |
| 433 | + "metadata": {}, |
| 434 | + "outputs": [], |
| 435 | + "source": [ |
| 436 | + "from plotly.subplots import make_subplots\n" |
| 437 | + ] |
| 438 | + }, |
336 | 439 | {
|
337 | 440 | "cell_type": "code",
|
338 | 441 | "execution_count": null,
|
|
341 | 444 | "source": [
|
342 | 445 | "# For each phase, plot the segmentation.\n",
|
343 | 446 | "\n",
|
344 |
| - "from plotly.subplots import make_subplots\n", |
345 | 447 | "\n",
|
346 | 448 | "phase = list(TASK_DICT[\"take_umbrella_out_of_umbrella_stand\"][\"phase\"].keys())[0]\n",
|
347 | 449 | "\n",
|
|
386 | 488 | "execution_count": null,
|
387 | 489 | "metadata": {},
|
388 | 490 | "outputs": [],
|
389 |
| - "source": [] |
| 491 | + "source": [ |
| 492 | + "demo[0].task_low_dim_state" |
| 493 | + ] |
390 | 494 | },
|
391 | 495 | {
|
392 | 496 | "cell_type": "code",
|
393 | 497 | "execution_count": null,
|
394 | 498 | "metadata": {},
|
395 | 499 | "outputs": [],
|
396 | 500 | "source": [
|
397 |
| - "# task_name = \"stack_wine\"\n", |
398 |
| - "# task_name = \"reach_target\"\n", |
| 501 | + "# - \"pick_and_lift\",\n", |
| 502 | + "# - \"pick_up_cup\",\n", |
| 503 | + "# - \"put_knife_on_chopping_board\",\n", |
| 504 | + "# - \"put_money_in_safe\",\n", |
| 505 | + "# - \"push_button\",\n", |
| 506 | + "# - \"reach_target\",\n", |
| 507 | + "# - \"slide_block_to_target\",\n", |
| 508 | + "# - \"stack_wine\",\n", |
| 509 | + "# - \"take_money_out_safe\",\n", |
| 510 | + "# - \"take_umbrella_out_of_umbrella_stand\",\n", |
| 511 | + "\n", |
| 512 | + "# task_name = \"pick_and_lift\"\n", |
| 513 | + "# task_name = \"pick_up_cup\"\n", |
| 514 | + "# task_name = \"put_knife_on_chopping_board\"\n", |
399 | 515 | "# task_name = \"put_money_in_safe\"\n",
|
| 516 | + "# task_name = \"push_button\"\n", |
| 517 | + "# task_name = \"reach_target\"\n", |
| 518 | + "# task_name = \"slide_block_to_target\"\n", |
| 519 | + "task_name = \"stack_wine\"\n", |
400 | 520 | "# task_name = \"take_money_out_safe\"\n",
|
401 |
| - "task_name = \"put_knife_on_chopping_board\"\n", |
| 521 | + "# task_name = \"take_umbrella_out_of_umbrella_stand\"\n", |
| 522 | + "\n", |
| 523 | + "n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n", |
| 524 | + "fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n", |
402 | 525 | "\n",
|
403 |
| - "for phase in TASK_DICT[task_name][\"phase_order\"]:\n", |
| 526 | + "for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n", |
404 | 527 | " print(f\"Phase: {phase}\")\n",
|
405 | 528 | " dset = RLBenchPlacementDataset(\n",
|
406 | 529 | " dataset_root=\"/data/rlbench10/\",\n",
|
407 | 530 | " task_name=task_name,\n",
|
408 |
| - " n_demos=1,\n", |
| 531 | + " demos=[0],\n", |
409 | 532 | " phase=phase,\n",
|
410 | 533 | " debugging=False,\n",
|
411 | 534 | " use_first_as_init_keyframe=False,\n",
|
| 535 | + " anchor_mode=\"background_robot_removed\",\n", |
| 536 | + " action_mode=\"gripper_and_object\",\n", |
412 | 537 | " )\n",
|
413 | 538 | "\n",
|
414 | 539 | " data = dset[0]\n",
|
|
417 | 542 | "\n",
|
418 | 543 | " print(list(data.keys()))\n",
|
419 | 544 | "\n",
|
| 545 | + " anchor_pc = data[\"init_anchor_pc\"]\n", |
| 546 | + " # Randomly downsample the anchor point cloud.\n", |
| 547 | + " n_pts = anchor_pc.shape[0]\n", |
| 548 | + " if n_pts > 1000:\n", |
| 549 | + " anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n", |
| 550 | + "\n", |
420 | 551 | " points = torch.cat(\n",
|
421 | 552 | " [\n",
|
422 | 553 | " data[\"init_action_pc\"],\n",
|
423 |
| - " data[\"init_anchor_pc\"],\n", |
| 554 | + " anchor_pc,\n", |
424 | 555 | " data[\"key_action_pc\"],\n",
|
425 | 556 | " ]\n",
|
426 | 557 | " )\n",
|
427 | 558 | " print(points.shape)\n",
|
428 | 559 | " seg = torch.cat(\n",
|
429 | 560 | " [\n",
|
430 | 561 | " torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
|
431 |
| - " torch.ones(data[\"init_anchor_pc\"].shape[0]),\n", |
| 562 | + " torch.ones(anchor_pc.shape[0]),\n", |
432 | 563 | " 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
|
433 | 564 | " ]\n",
|
434 | 565 | " )\n",
|
435 |
| - " fig = segmentation_fig(\n", |
| 566 | + "\n", |
| 567 | + " fig = segmentation_fig_rc(\n", |
436 | 568 | " points,\n",
|
437 | 569 | " seg.int(),\n",
|
438 | 570 | " labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
|
| 571 | + " fig=fig,\n", |
| 572 | + " row=1,\n", |
| 573 | + " column=ix+1,\n", |
| 574 | + " n_col=n_phases,\n", |
439 | 575 | " )\n",
|
440 |
| - " fig.show()\n", |
441 | 576 | "\n",
|
| 577 | + "fig.show()\n", |
442 | 578 | " "
|
443 | 579 | ]
|
444 | 580 | },
|
|
451 | 587 | "!ls /data/rlbench10/put_money_in_safe/variation0/episodes/episode0"
|
452 | 588 | ]
|
453 | 589 | },
|
| 590 | + { |
| 591 | + "cell_type": "code", |
| 592 | + "execution_count": null, |
| 593 | + "metadata": {}, |
| 594 | + "outputs": [], |
| 595 | + "source": [ |
| 596 | + "# Debugging dataset.\n", |
| 597 | + "\n" |
| 598 | + ] |
| 599 | + }, |
| 600 | + { |
| 601 | + "cell_type": "code", |
| 602 | + "execution_count": null, |
| 603 | + "metadata": {}, |
| 604 | + "outputs": [], |
| 605 | + "source": [] |
| 606 | + }, |
| 607 | + { |
| 608 | + "cell_type": "code", |
| 609 | + "execution_count": null, |
| 610 | + "metadata": {}, |
| 611 | + "outputs": [], |
| 612 | + "source": [ |
| 613 | + "np.unique(data[\"init_front_mask\"])" |
| 614 | + ] |
| 615 | + }, |
| 616 | + { |
| 617 | + "cell_type": "code", |
| 618 | + "execution_count": null, |
| 619 | + "metadata": {}, |
| 620 | + "outputs": [], |
| 621 | + "source": [ |
| 622 | + "\n", |
| 623 | + "\n", |
| 624 | + "from rpad.rlbench_utils.placement_dataset import obs_to_rgb_point_cloud\n", |
| 625 | + "\n", |
| 626 | + "\n" |
| 627 | + ] |
| 628 | + }, |
| 629 | + { |
| 630 | + "cell_type": "code", |
| 631 | + "execution_count": null, |
| 632 | + "metadata": {}, |
| 633 | + "outputs": [], |
| 634 | + "source": [] |
| 635 | + }, |
| 636 | + { |
| 637 | + "cell_type": "code", |
| 638 | + "execution_count": null, |
| 639 | + "metadata": {}, |
| 640 | + "outputs": [], |
| 641 | + "source": [ |
| 642 | + "\n", |
| 643 | + "unique_elements = set()\n", |
| 644 | + "\n", |
| 645 | + "\n", |
| 646 | + "for task in RLBENCH10_TASKS:\n", |
| 647 | + " print(f\"Task: {task}\")\n", |
| 648 | + "\n", |
| 649 | + " dset = RLBenchPlacementDataset(\n", |
| 650 | + " dataset_root=\"/data/rlbench10/\",\n", |
| 651 | + " task_name=task,\n", |
| 652 | + " demos=[0],\n", |
| 653 | + " phase=\"all\",\n", |
| 654 | + " debugging=True,\n", |
| 655 | + " )\n", |
| 656 | + " data = dset[0]\n", |
| 657 | + " init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(data[\"initial_obs\"])\n", |
| 658 | + " handle_mapping= load_handle_mapping(\n", |
| 659 | + " dset.dataset_root, dset.task_name, dset.variation\n", |
| 660 | + " )\n", |
| 661 | + " inv_h_map = {v: k for k, v in handle_mapping.items()}\n", |
| 662 | + "\n", |
| 663 | + " for id in np.unique(init_mask):\n", |
| 664 | + " print(inv_h_map[id])\n", |
| 665 | + " unique_elements.add(inv_h_map[id])" |
| 666 | + ] |
| 667 | + }, |
| 668 | + { |
| 669 | + "cell_type": "code", |
| 670 | + "execution_count": null, |
| 671 | + "metadata": {}, |
| 672 | + "outputs": [], |
| 673 | + "source": [ |
| 674 | + "unique_elements" |
| 675 | + ] |
| 676 | + }, |
454 | 677 | {
|
455 | 678 | "cell_type": "code",
|
456 | 679 | "execution_count": null,
|
|
0 commit comments