|
47 | 47 | ")"
|
48 | 48 | ]
|
49 | 49 | },
|
50 |
| - { |
51 |
| - "cell_type": "code", |
52 |
| - "execution_count": null, |
53 |
| - "metadata": {}, |
54 |
| - "outputs": [], |
55 |
| - "source": [ |
56 |
| - "data = dset[0]" |
57 |
| - ] |
58 |
| - }, |
59 |
| - { |
60 |
| - "cell_type": "code", |
61 |
| - "execution_count": null, |
62 |
| - "metadata": {}, |
63 |
| - "outputs": [], |
64 |
| - "source": [ |
65 |
| - "data[\"ignore_collisions\"]" |
66 |
| - ] |
67 |
| - }, |
68 |
| - { |
69 |
| - "cell_type": "code", |
70 |
| - "execution_count": null, |
71 |
| - "metadata": {}, |
72 |
| - "outputs": [], |
73 |
| - "source": [ |
74 |
| - "data.keys()" |
75 |
| - ] |
76 |
| - }, |
77 | 50 | {
|
78 | 51 | "cell_type": "code",
|
79 | 52 | "execution_count": null,
|
|
401 | 374 | "outputs": [],
|
402 | 375 | "source": [
|
403 | 376 | "# Get a mapping from handle id to handle name.\n",
|
404 |
| - "handle_mapping = load_handle_mapping(\"/data/rlbench10/\", task_name, 0)\n", |
| 377 | + "task_name = \"put_money_in_safe\"\n", |
| 378 | + "handle_mapping = load_handle_mapping(\"/data/rlbench10_collisions/\", task_name, 0)\n", |
405 | 379 | "rev_handle_mapping = {v: k for k, v in handle_mapping.items()}\n",
|
406 | 380 | "\n",
|
407 | 381 | "q_id = 100\n",
|
408 | 382 | "rev_handle_mapping[q_id]"
|
409 | 383 | ]
|
410 | 384 | },
|
| 385 | + { |
| 386 | + "cell_type": "code", |
| 387 | + "execution_count": null, |
| 388 | + "metadata": {}, |
| 389 | + "outputs": [], |
| 390 | + "source": [ |
| 391 | + "handle_mapping" |
| 392 | + ] |
| 393 | + }, |
411 | 394 | {
|
412 | 395 | "cell_type": "code",
|
413 | 396 | "execution_count": null,
|
|
510 | 493 | "N_DEMOS = 10\n",
|
511 | 494 | "# Create a dataset for that phase.\n",
|
512 | 495 | "dset = RLBenchPlacementDataset(\n",
|
513 |
| - " dataset_root=\"/data/rlbench10/\",\n", |
| 496 | + " dataset_root=\"/data/rlbench10_collisions/\",\n", |
514 | 497 | " task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
|
515 | 498 | " demos=range(N_DEMOS),\n",
|
516 | 499 | " phase=phase,\n",
|
|
571 | 554 | "\n",
|
572 | 555 | "# task_name = \"pick_and_lift\"\n",
|
573 | 556 | "# task_name = \"pick_up_cup\"\n",
|
574 |
| - "# task_name = \"put_knife_on_chopping_board\"\n", |
| 557 | + "task_name = \"put_knife_on_chopping_board\"\n", |
575 | 558 | "# task_name = \"put_money_in_safe\"\n",
|
576 | 559 | "# task_name = \"push_button\"\n",
|
577 | 560 | "# task_name = \"reach_target\"\n",
|
578 | 561 | "# task_name = \"slide_block_to_target\"\n",
|
579 |
| - "task_name = \"stack_wine\"\n", |
| 562 | + "# task_name = \"stack_wine\"\n", |
580 | 563 | "# task_name = \"take_money_out_safe\"\n",
|
581 | 564 | "# task_name = \"take_umbrella_out_of_umbrella_stand\"\n",
|
582 | 565 | "\n",
|
|
586 | 569 | "for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
|
587 | 570 | " print(f\"Phase: {phase}\")\n",
|
588 | 571 | " dset = RLBenchPlacementDataset(\n",
|
589 |
| - " dataset_root=\"/data/rlbench10/\",\n", |
| 572 | + " dataset_root=\"/data/rlbench10_collisions/\",\n", |
590 | 573 | " task_name=task_name,\n",
|
591 | 574 | " demos=[0],\n",
|
592 | 575 | " phase=phase,\n",
|
593 | 576 | " debugging=False,\n",
|
594 | 577 | " use_first_as_init_keyframe=False,\n",
|
595 | 578 | " anchor_mode=\"background_robot_removed\",\n",
|
596 | 579 | " action_mode=\"gripper_and_object\",\n",
|
| 580 | + " include_wrist_cam=True,\n", |
| 581 | + " gripper_in_first_phase=True,\n", |
597 | 582 | " )\n",
|
598 | 583 | "\n",
|
599 | 584 | " data = dset[0]\n",
|
|
0 commit comments