This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Integration with torchRL #23

@ShahRutav

Description


I modified the getting-started example to run TorchRL with RoboHive. Here is the modified example:

```python
import torch
import robohive
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform

from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))
```

Additionally, I changed this line so that the visual keys are filtered out when the R3M transform's output is concatenated with the other observation keys:

```python
vec_keys = [k for k in base_env.observation_spec.keys() if ((k != "pixels") and ("visual" not in k))]
```
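This filter selects keys by name only, so any top-level entry that is itself nested (a sub-dictionary of observations) still passes as long as its name does not contain "visual". The minimal sketch below assumes a plain dict with hypothetical key names in place of `base_env.observation_spec`, and contrasts the name-based rule with one that also drops nested entries:

```python
# Hypothetical observation layout: nested dicts play the role of nested
# specs/TensorDicts. None of these key names come from RoboHive itself.
obs_spec = {
    "pixels": "image",
    "visual_dict": {"rgb:cam0": "image"},
    "proprio": "vector",
    "extras": {"goal": "vector"},
}


def name_filter(spec):
    # Same rule as the modified vec_keys line: filter by key name only.
    return [k for k in spec if k != "pixels" and "visual" not in k]


def leaf_filter(spec):
    # Additionally drop entries whose value is nested, so that only leaf
    # entries (tensors, in the real env) would reach the concatenation.
    return [k for k in name_filter(spec) if not isinstance(spec[k], dict)]


print(name_filter(obs_spec))  # ['proprio', 'extras']
print(leaf_filter(obs_spec))  # ['proprio']
```

With a real TorchRL spec, the `isinstance(..., dict)` check would correspond to checking whether an entry is a nested `CompositeSpec`. Whether such an entry is actually present here depends on which keys `observation_spec` exposes for this environment, which the issue does not show.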

This leads to the following error:

```
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    print(env.rollout(3))
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1797, in rollout
    tensordict = self.reset()
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1480, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 760, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 1020, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3694, in _reset
    tensordict_reset = self._call(tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3676, in _call
    out_tensor = torch.cat(values, dim=self.dim)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 2785, in __torch_function__
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 5346, in _cat
    batch_size = list(list_of_tensordicts[0].batch_size)
AttributeError: 'Tensor' object has no attribute 'batch_size'
```
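In this traceback, `torch.cat` is routed to tensordict's `__torch_function__` handler even though the first value is a plain `Tensor`. PyTorch dispatches that way when any element of the list overrides `__torch_function__`, which suggests that at least one of the concatenated keys still maps to a nested TensorDict rather than to a tensor. The sketch below reproduces that failure mode with `torch` alone. `NestedStandIn` and `_cat_like_tensordict` are hypothetical stand-ins that only mimic the failing line shown above, not tensordict's actual implementation.

```python
import torch


class NestedStandIn:
    """Hypothetical stand-in for a TensorDict: not a torch.Tensor, but it
    implements the __torch_function__ protocol, as TensorDict does."""

    def __init__(self, batch_size):
        self.batch_size = torch.Size(batch_size)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.cat:
            return _cat_like_tensordict(*args, **kwargs)
        return NotImplemented


def _cat_like_tensordict(items, dim=0):
    # Mimics the failing line: the handler assumes every element is
    # TensorDict-like and reads .batch_size from the first one.
    return list(items[0].batch_size)


def reproduce():
    # One plain tensor plus one nested entry, as a concatenation of
    # observation keys would see them if a nested key slipped through.
    values = [torch.zeros(4), NestedStandIn([])]
    try:
        torch.cat(values, dim=-1)
    except AttributeError as err:
        return str(err)
    return None


print(reproduce())
```

Because the override is picked up from any position in the list, the error surfaces as a missing `batch_size` on the *first* element, which is a plain tensor. That is why the message mentions `'Tensor'` rather than the nested entry that actually triggered the dispatch.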

I am using the following package versions: robohive==0.6.0, tensordict==0.2.1, torchrl==0.2.1. Which versions did you use? @vmoens
