diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ee87df39ac1..6c53b5ebabe 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1392,7 +1392,7 @@ def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, _data in enumerate(tensordict.unbind(0)): + for i, _data in enumerate(tensordict.consolidate().unbind(0)): self.parent_channels[i].send(("step_and_maybe_reset", _data)) results = [None] * self.num_workers @@ -1489,7 +1489,7 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, data in enumerate(tensordict.unbind(0)): + for i, data in enumerate(tensordict.consolidate().unbind(0)): self.parent_channels[i].send(("step", data)) out_tds = [] for i, channel in enumerate(self.parent_channels): @@ -1576,7 +1576,7 @@ def _reset_no_buffers( needs_resetting, ) -> Tuple[TensorDictBase, TensorDictBase]: tdunbound = ( - tensordict.unbind(0) + tensordict.consolidate().unbind(0) if is_tensor_collection(tensordict) else [None] * self.num_workers ) @@ -1895,10 +1895,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda): i = 0 next_shared_tensordict = shared_tensordict.get("next") root_shared_tensordict = shared_tensordict.exclude("next") - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) + # TODO: restore this + # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + # raise RuntimeError( + # "tensordict must be placed in shared memory (share_memory_() or memmap_())" + # ) shared_tensordict = shared_tensordict.clone(False).unlock_() initialized = True @@ -2130,7 +2131,7 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(cur_td) + child_pipe.send(cur_td.consolidate()) del cur_td elif cmd == "step": @@ -2142,7 +2143,7 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(next_td) + child_pipe.send(next_td.consolidate()) del next_td elif cmd == "step_and_maybe_reset":