|
18 | 18 | "source": [
|
19 | 19 | "from rpad.rlbench_utils.placement_dataset import RLBenchPlacementDataset, load_handle_mapping, load_state_pos_dict, TASK_DICT\n",
|
20 | 20 | "import numpy as np\n",
|
21 |
| - "\n", |
| 21 | + "from rpad.rlbench_utils.task_info import RLBENCH_10_TASKS\n", |
22 | 22 | "from rpad.visualize_3d.plots import segmentation_fig\n",
|
23 | 23 | "import torch\n",
|
24 | 24 | "import matplotlib.pyplot as plt\n",
|
|
32 | 32 | "outputs": [],
|
33 | 33 | "source": [
|
34 | 34 | "dset = RLBenchPlacementDataset(\n",
|
35 |
| - " dataset_root=\"/data/rlbench10/\",\n", |
| 35 | + " dataset_root=\"/data/rlbench10_collisions/\",\n", |
36 | 36 | " # task_name=\"stack_wine\",\n",
|
37 | 37 | " # task_name=\"insert_onto_square_peg\",\n",
|
38 | 38 | " # task_name=\"insert_usb_in_computer\",\n",
|
|
42 | 42 | " # task_name=\"solve_puzzle\",\n",
|
43 | 43 | " # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
|
44 | 44 | " task_name=\"slide_block_to_target\",\n",
|
45 |
| - " demos=range(10),\n", |
| 45 | + " demos=range(100),\n", |
46 | 46 | " phase=\"all\",\n",
|
47 | 47 | ")"
|
48 | 48 | ]
|
49 | 49 | },
|
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": null, |
| 53 | + "metadata": {}, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | + "data = dset[0]" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "code", |
| 61 | + "execution_count": null, |
| 62 | + "metadata": {}, |
| 63 | + "outputs": [], |
| 64 | + "source": [ |
| 65 | + "data[\"ignore_collisions\"]" |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "code", |
| 70 | + "execution_count": null, |
| 71 | + "metadata": {}, |
| 72 | + "outputs": [], |
| 73 | + "source": [ |
| 74 | + "data.keys()" |
| 75 | + ] |
| 76 | + }, |
50 | 77 | {
|
51 | 78 | "cell_type": "code",
|
52 | 79 | "execution_count": null,
|
|
707 | 734 | "unique_elements"
|
708 | 735 | ]
|
709 | 736 | },
|
| 737 | + { |
| 738 | + "cell_type": "code", |
| 739 | + "execution_count": null, |
| 740 | + "metadata": {}, |
| 741 | + "outputs": [], |
| 742 | + "source": [ |
| 743 | + "for task_name in RLBENCH_10_TASKS:\n", |
| 744 | + " print(\"--------------------\")\n", |
| 745 | + " print(f\"Task: {task_name}\")\n", |
| 746 | + " print(\"--------------------\")\n", |
| 747 | + " for phase in TASK_DICT[task_name][\"phase\"].keys():\n", |
| 748 | + "\n", |
| 749 | + " dset = RLBenchPlacementDataset(\n", |
| 750 | + " dataset_root=\"/data/rlbench10_collisions/\",\n", |
| 751 | + " # task_name=\"stack_wine\",\n", |
| 752 | + " # task_name=\"insert_onto_square_peg\",\n", |
| 753 | + " # task_name=\"insert_usb_in_computer\",\n", |
| 754 | + " # task_name=\"phone_on_base\",\n", |
| 755 | + " # task_name=\"put_toilet_roll_on_stand\",\n", |
| 756 | + " # task_name=\"place_hanger_on_rack\",\n", |
| 757 | + " # task_name=\"solve_puzzle\",\n", |
| 758 | + " # task_name=\"take_umbrella_out_of_umbrella_stand\",\n", |
| 759 | + " task_name=task_name,\n", |
| 760 | + " demos=range(100),\n", |
| 761 | + " phase=phase,\n", |
| 762 | + " ) \n", |
| 763 | + " ignore_collisions_all = []\n", |
| 764 | + " for i in range(len(dset)):\n", |
| 765 | + " try:\n", |
| 766 | + " data = dset[i]\n", |
| 767 | + " ignore_collisions_all.append(data[\"ignore_collisions\"])\n", |
| 768 | + " except:\n", |
| 769 | + " print(f\"Error in task {task_name}, phase {phase}, demo {i}\")\n", |
| 770 | + " ignore_all = (np.array(ignore_collisions_all).any()) \n", |
| 771 | + " print(f\"Phase: {phase}; Ignore Collisions: {ignore_all}\")\n", |
| 772 | + "\n", |
| 773 | + "\n" |
| 774 | + ] |
| 775 | + }, |
710 | 776 | {
|
711 | 777 | "cell_type": "code",
|
712 | 778 | "execution_count": null,
|
|
0 commit comments