[BUG] DataCollectors fail when device is set to MPS #2858

@LCarmi

Describe the bug

When running experiments with multiprocess-based sampling of trajectories on macOS, the initialization of the data collectors fails if the collector device is set to "mps".

To Reproduce

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device="mps",
        storing_device="cpu",
        # cat_results="stack",
    )
    for i, data in enumerate(collector):
        if i == 2:
            print(data)
            break

This fails as follows:

Traceback (most recent call last):
  File "..././torchrl_test_mps_fail.py", line 9, in <module>
    collector = MultiSyncDataCollector(
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1779, in __init__
    self._run_processes()
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1976, in _run_processes
    proc.start()
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File ".../.venv/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 607, in reduce_storage
    metadata = storage._share_filename_cpu_()
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 450, in wrapper
    return fn(self, *args, **kwargs)
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 529, in _share_filename_cpu_
    return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU

System info

>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.2 2.2.4 3.10.16 (main, Dec  3 2024, 17:27:57) [Clang 16.0.6 ] darwin

Reason and Possible fixes

I suspect this issue boils down to a combination of:

  • a limitation of the mps device: tensors on mps do not work with the pickle-based sharing of parameters (the traceback ends in "_share_filename_: only available on CPU")
  • a limitation of torchrl, which assumes a spawn-based multiprocessing context
    • torchrl imposes the spawn context itself:
      mp.set_start_method("spawn")
    • switching to a fork-based context does not help: forcing fork through multiprocessing.set_start_method('fork') gives a warning and makes the collectors crash
  • the spawn multiprocessing context itself, which uses pickle to copy the state of the parent process into the newly spawned one
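
To illustrate the last point without needing torch or an MPS device, here is a minimal sketch using only the standard library. The traceback shows that starting a spawned process serialises its state with ForkingPickler, so anything that refuses to be reduced makes proc.start() raise. DeviceBoundBuffer and spawn_pickle_error are hypothetical names made up for this example, and the stand-in only imitates the error message; it is not the real PyTorch reduction code.

```python
from multiprocessing.reduction import ForkingPickler


class DeviceBoundBuffer:
    """Hypothetical stand-in for a tensor whose storage cannot be shared."""

    def __reduce__(self):
        # Imitates the message from the traceback above; this is an
        # assumption for illustration, not PyTorch's actual code path.
        raise RuntimeError("_share_filename_: only available on CPU")


def spawn_pickle_error(obj):
    """Return the error raised while serialising obj for a child process, or None."""
    try:
        # Same pickler that multiprocessing uses when a process is started
        # under the spawn context.
        ForkingPickler.dumps(obj)
    except Exception as err:
        return f"{type(err).__name__}: {err}"
    return None


# Plain Python data crosses the process boundary without trouble.
print(spawn_pickle_error({"weights": [0.1, 0.2]}))
# A device-bound object fails before any child process exists.
print(spawn_pickle_error({"weights": DeviceBoundBuffer()}))
```

If this picture is right, the failure happens in the parent process during serialisation, which would explain why the collector crashes in __init__ rather than inside a worker.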

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
