Skip to content

Commit 657288a

Browse files
committed
ignore collisions
1 parent d2ad10c commit 657288a

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

notebooks/explore_dset2.ipynb

+69-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"source": [
1919
"from rpad.rlbench_utils.placement_dataset import RLBenchPlacementDataset, load_handle_mapping, load_state_pos_dict, TASK_DICT\n",
2020
"import numpy as np\n",
21-
"\n",
21+
"from rpad.rlbench_utils.task_info import RLBENCH_10_TASKS\n",
2222
"from rpad.visualize_3d.plots import segmentation_fig\n",
2323
"import torch\n",
2424
"import matplotlib.pyplot as plt\n",
@@ -32,7 +32,7 @@
3232
"outputs": [],
3333
"source": [
3434
"dset = RLBenchPlacementDataset(\n",
35-
" dataset_root=\"/data/rlbench10/\",\n",
35+
" dataset_root=\"/data/rlbench10_collisions/\",\n",
3636
" # task_name=\"stack_wine\",\n",
3737
" # task_name=\"insert_onto_square_peg\",\n",
3838
" # task_name=\"insert_usb_in_computer\",\n",
@@ -42,11 +42,38 @@
4242
" # task_name=\"solve_puzzle\",\n",
4343
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
4444
" task_name=\"slide_block_to_target\",\n",
45-
" demos=range(10),\n",
45+
" demos=range(100),\n",
4646
" phase=\"all\",\n",
4747
")"
4848
]
4949
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"data = dset[0]"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": null,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"data[\"ignore_collisions\"]"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"data.keys()"
75+
]
76+
},
5077
{
5178
"cell_type": "code",
5279
"execution_count": null,
@@ -707,6 +734,45 @@
707734
"unique_elements"
708735
]
709736
},
737+
{
738+
"cell_type": "code",
739+
"execution_count": null,
740+
"metadata": {},
741+
"outputs": [],
742+
"source": [
743+
"for task_name in RLBENCH_10_TASKS:\n",
744+
" print(\"--------------------\")\n",
745+
" print(f\"Task: {task_name}\")\n",
746+
" print(\"--------------------\")\n",
747+
" for phase in TASK_DICT[task_name][\"phase\"].keys():\n",
748+
"\n",
749+
" dset = RLBenchPlacementDataset(\n",
750+
" dataset_root=\"/data/rlbench10_collisions/\",\n",
751+
" # task_name=\"stack_wine\",\n",
752+
" # task_name=\"insert_onto_square_peg\",\n",
753+
" # task_name=\"insert_usb_in_computer\",\n",
754+
" # task_name=\"phone_on_base\",\n",
755+
" # task_name=\"put_toilet_roll_on_stand\",\n",
756+
" # task_name=\"place_hanger_on_rack\",\n",
757+
" # task_name=\"solve_puzzle\",\n",
758+
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
759+
" task_name=task_name,\n",
760+
" demos=range(100),\n",
761+
" phase=phase,\n",
762+
" ) \n",
763+
" ignore_collisions_all = []\n",
764+
" for i in range(len(dset)):\n",
765+
" try:\n",
766+
" data = dset[i]\n",
767+
" ignore_collisions_all.append(data[\"ignore_collisions\"])\n",
768+
" except:\n",
769+
" print(f\"Error in task {task_name}, phase {phase}, demo {i}\")\n",
770+
" ignore_all = (np.array(ignore_collisions_all).any()) \n",
771+
" print(f\"Phase: {phase}; Ignore Collisions: {ignore_all}\")\n",
772+
"\n",
773+
"\n"
774+
]
775+
},
710776
{
711777
"cell_type": "code",
712778
"execution_count": null,

src/rpad/rlbench_utils/placement_dataset.py

+7
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,12 @@ def extract_pose(obs, key):
467467
T_init_key = T_action_key_world @ np.linalg.inv(T_action_init_world)
468468
T_anchor_key_world = extract_pose(key_obs, "anchor_pose_name")
469469

470+
if hasattr(initial_obs, "ignore_collisions"):
471+
ignore_collisions = initial_obs.ignore_collisions
472+
ignore_collisions = torch.from_numpy(ignore_collisions.astype(np.int32))
473+
else:
474+
ignore_collisions = None
475+
470476
return {
471477
"init_action_rgb": torch.from_numpy(init_action_rgb),
472478
"init_action_pc": torch.from_numpy(init_action_point_cloud),
@@ -489,4 +495,5 @@ def extract_pose(obs, key):
489495
"key_front_mask": torch.from_numpy(key_obs.front_mask.astype(np.int32)),
490496
"phase": phase,
491497
"phase_onehot": torch.from_numpy(phase_onehot),
498+
"ignore_collisions": ignore_collisions,
492499
}

0 commit comments

Comments
 (0)