This repository was archived by the owner on Jun 2, 2025. It is now read-only.
Integration with torchRL #23
Open
Description
I modified the getting-started example to run TorchRL with RoboHive. Here is the modified example:
import torch
import robohive
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform
from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy

# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu")  # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))
Additionally, I changed this line so that the visual keys are filtered out when the R3M output is concatenated with the other observation keys:
vec_keys = [k for k in base_env.observation_spec.keys() if ((k != "pixels") and ("visual" not in k))]
This leads to the following error:
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    print(env.rollout(3))
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1797, in rollout
    tensordict = self.reset()
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1480, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 760, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 1020, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3694, in _reset
    tensordict_reset = self._call(tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3676, in _call
    out_tensor = torch.cat(values, dim=self.dim)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 2785, in __torch_function__
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 5346, in _cat
    batch_size = list(list_of_tensordicts[0].batch_size)
AttributeError: 'Tensor' object has no attribute 'batch_size'
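The last frames of this traceback show `torch.cat(values, ...)` being dispatched to tensordict's `__torch_function__`, which implies that at least one entry of `values` is a TensorDict, while `list_of_tensordicts[0]` is a plain Tensor. So the list being concatenated appears to mix tensors and a nested TensorDict. One plausible cause (an assumption, not confirmed in this thread) is that a key kept by the `vec_keys` filter points to a nested entry of the observation spec instead of a leaf tensor. Below is a minimal plain-Python sketch of a filter that keeps only leaf entries; nested dicts stand in for the spec, and the helper `leaf_vector_keys` and the key names `visual_dict`, `obs_dict`, `qp` are hypothetical. With torchrl itself, `observation_spec.keys(include_nested=True, leaves_only=True)` may serve the same purpose, if your version supports those arguments.

```python
# Hypothetical helper: plain nested dicts stand in for torchrl's nested
# observation spec; only leaf (non-dict) entries are returned as key tuples.
from typing import Any, Dict, List, Tuple


def leaf_vector_keys(
    spec: Dict[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, ...]]:
    """Return nested keys of leaf entries, skipping pixel and visual entries."""
    keys: List[Tuple[str, ...]] = []
    for name, value in spec.items():
        if name == "pixels" or "visual" in name:
            continue  # same filter as the modified vec_keys line
        key = prefix + (name,)
        if isinstance(value, dict):
            # A nested container is not a tensor: descend instead of keeping it,
            # so the concatenation never receives a mix of tensors and containers.
            keys.extend(leaf_vector_keys(value, key))
        else:
            keys.append(key)
    return keys


if __name__ == "__main__":
    # Hypothetical spec layout: shapes are placeholders, names are made up.
    spec = {
        "pixels": (3, 224, 224),
        "visual_dict": {"rgb_left_cam": (3, 224, 224)},
        "qp": (9,),
        "obs_dict": {"time": (1,), "reach_err": (3,)},
    }
    print(leaf_vector_keys(spec))
```

Running it prints `[('qp',), ('obs_dict', 'time'), ('obs_dict', 'reach_err')]`: the pixel and visual entries are dropped, and `obs_dict` contributes its leaves rather than itself.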
I am using the following package versions: robohive==0.6.0, tensordict==0.2.1, torchrl==0.2.1. Which versions did you use? @vmoens
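To make version comparisons easy on both sides, here is a small sketch that uses only the standard library's `importlib.metadata` (Python 3.8+) to print the installed versions of the three packages named above. The helper name `package_versions` is hypothetical.

```python
# Print installed versions of the packages mentioned in this report.
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # package is not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in package_versions(["robohive", "tensordict", "torchrl"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```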