Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 14, 2024
1 parent 35df59e commit bd5c17e
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ def _step_and_maybe_reset_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:

for i, _data in enumerate(tensordict.unbind(0)):
for i, _data in enumerate(tensordict.consolidate().unbind(0)):
self.parent_channels[i].send(("step_and_maybe_reset", _data))

results = [None] * self.num_workers
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def step_and_maybe_reset(
def _step_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
for i, data in enumerate(tensordict.unbind(0)):
for i, data in enumerate(tensordict.consolidate().unbind(0)):
self.parent_channels[i].send(("step", data))
out_tds = []
for i, channel in enumerate(self.parent_channels):
Expand Down Expand Up @@ -1576,7 +1576,7 @@ def _reset_no_buffers(
needs_resetting,
) -> Tuple[TensorDictBase, TensorDictBase]:
tdunbound = (
tensordict.unbind(0)
tensordict.consolidate().unbind(0)
if is_tensor_collection(tensordict)
else [None] * self.num_workers
)
Expand Down Expand Up @@ -1895,10 +1895,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
i = 0
next_shared_tensordict = shared_tensordict.get("next")
root_shared_tensordict = shared_tensordict.exclude("next")
if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
raise RuntimeError(
"tensordict must be placed in shared memory (share_memory_() or memmap_())"
)
# TODO: restore this
# if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
# raise RuntimeError(
# "tensordict must be placed in shared memory (share_memory_() or memmap_())"
# )
shared_tensordict = shared_tensordict.clone(False).unlock_()

initialized = True
Expand Down Expand Up @@ -2130,7 +2131,7 @@ def _run_worker_pipe_direct(
event.record()
event.synchronize()
mp_event.set()
child_pipe.send(cur_td)
child_pipe.send(cur_td.consolidate())
del cur_td

elif cmd == "step":
Expand All @@ -2142,7 +2143,7 @@ def _run_worker_pipe_direct(
event.record()
event.synchronize()
mp_event.set()
child_pipe.send(next_td)
child_pipe.send(next_td.consolidate())
del next_td

elif cmd == "step_and_maybe_reset":
Expand Down

0 comments on commit bd5c17e

Please sign in to comment.