|
374 | 374 | "outputs": [],
|
375 | 375 | "source": [
|
376 | 376 | "# Get a mapping from handle id to handle name.\n",
|
377 |
| - "task_name = \"put_money_in_safe\"\n", |
| 377 | + "# task_name = \"put_money_in_safe\"\n", |
| 378 | + "task_name = \"put_knife_on_chopping_board\"\n", |
378 | 379 | "handle_mapping = load_handle_mapping(\"/data/rlbench10_collisions/\", task_name, 0)\n",
|
379 | 380 | "rev_handle_mapping = {v: k for k, v in handle_mapping.items()}\n",
|
380 | 381 | "\n",
|
|
388 | 389 | "metadata": {},
|
389 | 390 | "outputs": [],
|
390 | 391 | "source": [
|
391 |
| - "handle_mapping" |
| 392 | + "set(handle_mapping.keys())" |
392 | 393 | ]
|
393 | 394 | },
|
394 | 395 | {
|
|
563 | 564 | "# task_name = \"take_money_out_safe\"\n",
|
564 | 565 | "# task_name = \"take_umbrella_out_of_umbrella_stand\"\n",
|
565 | 566 | "\n",
|
566 |
| - "n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n", |
567 |
| - "fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n", |
| 567 | + "for i in range(4):\n", |
568 | 568 | "\n",
|
569 |
| - "for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n", |
570 |
| - " print(f\"Phase: {phase}\")\n", |
571 |
| - " dset = RLBenchPlacementDataset(\n", |
572 |
| - " dataset_root=\"/data/rlbench10_collisions/\",\n", |
573 |
| - " task_name=task_name,\n", |
574 |
| - " demos=[0],\n", |
575 |
| - " phase=phase,\n", |
576 |
| - " debugging=False,\n", |
577 |
| - " use_first_as_init_keyframe=False,\n", |
578 |
| - " anchor_mode=\"background_robot_removed\",\n", |
579 |
| - " action_mode=\"gripper_and_object\",\n", |
580 |
| - " include_wrist_cam=True,\n", |
581 |
| - " gripper_in_first_phase=True,\n", |
582 |
| - " )\n", |
| 569 | + " n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n", |
| 570 | + " fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n", |
583 | 571 | "\n",
|
584 |
| - " data = dset[0]\n", |
585 |
| - "\n", |
586 |
| - " # Plot segmentation with segmentation_fig\n", |
587 |
| - "\n", |
588 |
| - " print(list(data.keys()))\n", |
589 |
| - "\n", |
590 |
| - " anchor_pc = data[\"init_anchor_pc\"]\n", |
591 |
| - " # Randomly downsample the anchor point cloud.\n", |
592 |
| - " n_pts = anchor_pc.shape[0]\n", |
593 |
| - " if n_pts > 1000:\n", |
594 |
| - " anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n", |
595 |
| - "\n", |
596 |
| - " points = torch.cat(\n", |
597 |
| - " [\n", |
598 |
| - " data[\"init_action_pc\"],\n", |
599 |
| - " anchor_pc,\n", |
600 |
| - " data[\"key_action_pc\"],\n", |
601 |
| - " ]\n", |
602 |
| - " )\n", |
603 |
| - " print(points.shape)\n", |
604 |
| - " seg = torch.cat(\n", |
605 |
| - " [\n", |
606 |
| - " torch.zeros(data[\"init_action_pc\"].shape[0]),\n", |
607 |
| - " torch.ones(anchor_pc.shape[0]),\n", |
608 |
| - " 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n", |
609 |
| - " ]\n", |
610 |
| - " )\n", |
| 572 | + " for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n", |
| 573 | + " print(f\"Phase: {phase}\")\n", |
| 574 | + " dset = RLBenchPlacementDataset(\n", |
| 575 | + " dataset_root=\"/data/rlbench10_collisions/\",\n", |
| 576 | + " task_name=task_name,\n", |
| 577 | + " demos=range(100),\n", |
| 578 | + " phase=phase,\n", |
| 579 | + " debugging=False,\n", |
| 580 | + " use_first_as_init_keyframe=False,\n", |
| 581 | + " anchor_mode=\"background_robot_removed\",\n", |
| 582 | + " action_mode=\"gripper_and_object\",\n", |
| 583 | + " include_wrist_cam=True,\n", |
| 584 | + " gripper_in_first_phase=True,\n", |
| 585 | + " )\n", |
611 | 586 | "\n",
|
612 |
| - " fig = segmentation_fig_rc(\n", |
613 |
| - " points,\n", |
614 |
| - " seg.int(),\n", |
615 |
| - " labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n", |
616 |
| - " fig=fig,\n", |
617 |
| - " row=1,\n", |
618 |
| - " column=ix+1,\n", |
619 |
| - " n_col=n_phases,\n", |
620 |
| - " )\n", |
| 587 | + " data = dset[i]\n", |
621 | 588 | "\n",
|
622 |
| - "fig.show()\n", |
| 589 | + " # Plot segmentation with segmentation_fig\n", |
| 590 | + "\n", |
| 591 | + " print(list(data.keys()))\n", |
| 592 | + "\n", |
| 593 | + " anchor_pc = data[\"init_anchor_pc\"]\n", |
| 594 | + " # Randomly downsample the anchor point cloud.\n", |
| 595 | + " n_pts = anchor_pc.shape[0]\n", |
| 596 | + " if n_pts > 1000:\n", |
| 597 | + " anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n", |
| 598 | + "\n", |
| 599 | + " points = torch.cat(\n", |
| 600 | + " [\n", |
| 601 | + " data[\"init_action_pc\"],\n", |
| 602 | + " anchor_pc,\n", |
| 603 | + " data[\"key_action_pc\"],\n", |
| 604 | + " ]\n", |
| 605 | + " )\n", |
| 606 | + " print(points.shape)\n", |
| 607 | + " seg = torch.cat(\n", |
| 608 | + " [\n", |
| 609 | + " torch.zeros(data[\"init_action_pc\"].shape[0]),\n", |
| 610 | + " torch.ones(anchor_pc.shape[0]),\n", |
| 611 | + " 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n", |
| 612 | + " ]\n", |
| 613 | + " )\n", |
| 614 | + "\n", |
| 615 | + " fig = segmentation_fig_rc(\n", |
| 616 | + " points,\n", |
| 617 | + " seg.int(),\n", |
| 618 | + " labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n", |
| 619 | + " fig=fig,\n", |
| 620 | + " row=1,\n", |
| 621 | + " column=ix+1,\n", |
| 622 | + " n_col=n_phases,\n", |
| 623 | + " )\n", |
| 624 | + "\n", |
| 625 | + " fig.show()\n", |
623 | 626 | " "
|
624 | 627 | ]
|
625 | 628 | },
|
|
0 commit comments