Skip to content

Commit bd5c17e

Browse files
Author: Vincent Moens (committed)
init
1 parent 35df59e commit bd5c17e

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

torchrl/envs/batched_envs.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1392,7 +1392,7 @@ def _step_and_maybe_reset_no_buffers(
13921392
self, tensordict: TensorDictBase
13931393
) -> Tuple[TensorDictBase, TensorDictBase]:
13941394

1395-
for i, _data in enumerate(tensordict.unbind(0)):
1395+
for i, _data in enumerate(tensordict.consolidate().unbind(0)):
13961396
self.parent_channels[i].send(("step_and_maybe_reset", _data))
13971397

13981398
results = [None] * self.num_workers
@@ -1489,7 +1489,7 @@ def step_and_maybe_reset(
14891489
def _step_no_buffers(
14901490
self, tensordict: TensorDictBase
14911491
) -> Tuple[TensorDictBase, TensorDictBase]:
1492-
for i, data in enumerate(tensordict.unbind(0)):
1492+
for i, data in enumerate(tensordict.consolidate().unbind(0)):
14931493
self.parent_channels[i].send(("step", data))
14941494
out_tds = []
14951495
for i, channel in enumerate(self.parent_channels):
@@ -1576,7 +1576,7 @@ def _reset_no_buffers(
15761576
needs_resetting,
15771577
) -> Tuple[TensorDictBase, TensorDictBase]:
15781578
tdunbound = (
1579-
tensordict.unbind(0)
1579+
tensordict.consolidate().unbind(0)
15801580
if is_tensor_collection(tensordict)
15811581
else [None] * self.num_workers
15821582
)
@@ -1895,10 +1895,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
18951895
i = 0
18961896
next_shared_tensordict = shared_tensordict.get("next")
18971897
root_shared_tensordict = shared_tensordict.exclude("next")
1898-
if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
1899-
raise RuntimeError(
1900-
"tensordict must be placed in shared memory (share_memory_() or memmap_())"
1901-
)
1898+
# TODO: restore this
1899+
# if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
1900+
# raise RuntimeError(
1901+
# "tensordict must be placed in shared memory (share_memory_() or memmap_())"
1902+
# )
19021903
shared_tensordict = shared_tensordict.clone(False).unlock_()
19031904

19041905
initialized = True
@@ -2130,7 +2131,7 @@ def _run_worker_pipe_direct(
21302131
event.record()
21312132
event.synchronize()
21322133
mp_event.set()
2133-
child_pipe.send(cur_td)
2134+
child_pipe.send(cur_td.consolidate())
21342135
del cur_td
21352136

21362137
elif cmd == "step":
@@ -2142,7 +2143,7 @@ def _run_worker_pipe_direct(
21422143
event.record()
21432144
event.synchronize()
21442145
mp_event.set()
2145-
child_pipe.send(next_td)
2146+
child_pipe.send(next_td.consolidate())
21462147
del next_td
21472148

21482149
elif cmd == "step_and_maybe_reset":

0 commit comments

Comments (0)