This repository provides a JAX-based implementation for fine-tuning and running inference with Google's RT-1-X model. It supports custom datasets for fine-tuning and custom environments for inference, and includes integration with the MetaWorld environment.
- Fine-tuning: Easily fine-tune the RT-1 model using custom datasets.
- Inference: Run inference in custom environments.
- Environment Integration: Pre-integrated with the MetaWorld environment.
RT-1 (Robotics Transformer 1) is a Transformer-based model for robot control that maps camera images and a natural-language instruction to robot actions. RT-1-X is the variant trained on the Open X-Embodiment dataset.
A JAX checkpoint compatible with the Flax checkpoint loader can be downloaded with:
gsutil -m cp -r gs://gdm-robotics-open-x-embodiment/open_x_embodiment_and_rt_x_oss/rt_1_x_jax .
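As a minimal sketch of loading the downloaded checkpoint, the snippet below calls `flax.training.checkpoints.restore_checkpoint` with no target, which returns the checkpoint contents as a nested dictionary. The local directory name `rt_1_x_jax` follows from the gsutil command above. The helper name `load_rt1x_checkpoint` is hypothetical, and it is an assumption that the downloaded directory is in a layout this loader recognizes; this is not taken from the repository's own code.

```python
import os


def load_rt1x_checkpoint(ckpt_dir="rt_1_x_jax"):
    """Restore the raw state dict from a downloaded checkpoint directory.

    Hypothetical helper: the function name and default path are assumptions.
    """
    # Use an absolute path, in case the checkpoint backend requires one.
    ckpt_dir = os.path.abspath(ckpt_dir)
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}")

    # Imported lazily so the path check above works without Flax installed.
    from flax.training import checkpoints

    # With target=None the checkpoint is returned as a nested dictionary.
    state_dict = checkpoints.restore_checkpoint(ckpt_dir, target=None)
    if state_dict is None:
        raise ValueError(f"No checkpoint could be restored from {ckpt_dir}")
    return state_dict
```

Calling `load_rt1x_checkpoint()` from the directory where the gsutil command was run and printing the keys of the returned dictionary is a quick way to see which top-level collections the checkpoint stores.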
Fine-tuning uses datasets in the RLDS format. To create or modify datasets:
- Refer to the data/datasets.py file to define your custom dataset.
- Use RLDS-compatible tools for dataset preparation. For example, see the data collection script in the RLDS Dataset Builder repository (rlds_dataset_builder).
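To make the RLDS layout concrete, here is a rough sketch of how one episode is organized: an episode holds a sequence of steps, and each step carries `observation`, `action`, `reward`, `discount`, `is_first`, `is_last`, and `is_terminal`. The helper `make_rlds_episode`, the `state` observation key, and the constant discount are illustrative assumptions. A real RLDS dataset stores steps as a nested dataset rather than a Python list, and the fields expected by data/datasets.py may differ.

```python
def make_rlds_episode(observations, actions, rewards, terminated=True):
    """Pack per-step lists into an RLDS-style episode dictionary (illustrative)."""
    if not (len(observations) == len(actions) == len(rewards)):
        raise ValueError("observations, actions and rewards must have equal length")
    if not observations:
        raise ValueError("an episode needs at least one step")

    last = len(observations) - 1
    steps = []
    for i, (obs, act, rew) in enumerate(zip(observations, actions, rewards)):
        steps.append({
            "observation": obs,
            "action": act,
            "reward": float(rew),
            "discount": 1.0,  # assumption: constant discount on every step
            "is_first": i == 0,
            "is_last": i == last,
            # Terminal only if the episode ended in a terminal state,
            # as opposed to being cut off (for example by a time limit).
            "is_terminal": i == last and terminated,
        })
    return {"steps": steps}


# Toy three-step episode; the "state" key is a hypothetical observation field.
episode = make_rlds_episode(
    observations=[{"state": [0.0, 0.0]}, {"state": [0.1, 0.0]}, {"state": [0.2, 0.1]}],
    actions=[[0.1, 0.0], [0.1, 0.1], [0.0, 0.0]],
    rewards=[0.0, 0.0, 1.0],
)
print(len(episode["steps"]))
```

Passing `terminated=False` marks the final step as `is_last` but not `is_terminal`, which is how a truncated episode is distinguished from one that reached a terminal state.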
Once your dataset is ready, execute the following command:
python train.py
To customize environments for inference:
- Define your custom environment in the envs/ directory.
- Integrate the environment into eval.py for inference.
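The two steps above can be sketched with a toy environment and a rollout loop. Everything here is hypothetical: `ToyReachEnv`, `proportional_policy`, and `run_episode` are illustrative names, the `reset`/`step` interface is a common convention rather than the one this repository's eval.py necessarily expects, and the hand-written policy stands in for the RT-1-X model.

```python
class ToyReachEnv:
    """Hypothetical 1-D reaching task exposing a reset/step interface."""

    def __init__(self, goal=1.0, max_steps=20, tolerance=0.05):
        self.goal = goal
        self.max_steps = max_steps
        self.tolerance = tolerance
        self.position = 0.0
        self.steps = 0

    def _observation(self):
        return {"position": self.position, "goal": self.goal}

    def reset(self):
        self.position = 0.0
        self.steps = 0
        return self._observation()

    def step(self, action):
        # Clip the action so a single step moves at most 0.2 units.
        action = max(-0.2, min(0.2, float(action)))
        self.position += action
        self.steps += 1
        success = abs(self.goal - self.position) <= self.tolerance
        done = success or self.steps >= self.max_steps
        reward = 1.0 if success else 0.0
        return self._observation(), reward, done, {"success": success}


def proportional_policy(observation):
    """Stand-in for the model: move a fraction of the way toward the goal."""
    return 0.5 * (observation["goal"] - observation["position"])


def run_episode(env, policy):
    """Roll out one episode and return (total_reward, success)."""
    obs = env.reset()
    done = False
    total_reward = 0.0
    info = {"success": False}
    while not done:
        obs, reward, done, info = env.step(policy(obs))
        total_reward += reward
    return total_reward, info["success"]


total_reward, success = run_episode(ToyReachEnv(), proportional_policy)
print(total_reward, success)
```

Keeping the policy as a plain callable that maps an observation to an action means the same loop can be reused when the stand-in is swapped for a model-backed policy.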
After integrating your environment, execute the following command:
python eval.py
This repository is pre-integrated with the MetaWorld environment. Additional environments can be added by following the instructions in the envs/ directory.
- JAX
- RLDS
- Python 3.10
Install the required dependencies:
pip install -r requirements.txt
Install JAX for GPU (CUDA 11):
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Install JAX for TPU:
pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
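After installing, it can help to confirm which backend JAX actually picked up, since a mismatched install usually falls back to CPU. The check below is a small sketch; the helper name `describe_jax_install` is hypothetical, while `jax.__version__`, `jax.default_backend()`, and `jax.devices()` are standard JAX calls.

```python
import importlib.util


def describe_jax_install():
    """Return a one-line summary of the local JAX installation, if any."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    try:
        import jax
    except ImportError as err:
        return f"jax failed to import: {err}"

    devices = jax.devices()
    return (
        f"jax {jax.__version__}, "
        f"backend={jax.default_backend()}, devices={len(devices)}"
    )


print(describe_jax_install())
```

If the reported backend is `cpu` on a machine with an accelerator, the accelerator-specific install command most likely did not take effect.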
.
├── data
│ └── datasets.py # Custom dataset definitions
├── envs
│ ├── metaworld # MetaWorld environment integration
│ └── ... # Add your custom environments here
├── train.py # Fine-tuning script
├── eval.py # Inference script
└── requirements.txt # Python dependencies
Thanks to Google for the RT-1 model and the robotics research community for inspiration and tools.