@@ -1392,7 +1392,7 @@ def _step_and_maybe_reset_no_buffers(
1392
1392
self , tensordict : TensorDictBase
1393
1393
) -> Tuple [TensorDictBase , TensorDictBase ]:
1394
1394
1395
- for i , _data in enumerate (tensordict .unbind (0 )):
1395
+ for i , _data in enumerate (tensordict .consolidate (). unbind (0 )):
1396
1396
self .parent_channels [i ].send (("step_and_maybe_reset" , _data ))
1397
1397
1398
1398
results = [None ] * self .num_workers
@@ -1489,7 +1489,7 @@ def step_and_maybe_reset(
1489
1489
def _step_no_buffers (
1490
1490
self , tensordict : TensorDictBase
1491
1491
) -> Tuple [TensorDictBase , TensorDictBase ]:
1492
- for i , data in enumerate (tensordict .unbind (0 )):
1492
+ for i , data in enumerate (tensordict .consolidate (). unbind (0 )):
1493
1493
self .parent_channels [i ].send (("step" , data ))
1494
1494
out_tds = []
1495
1495
for i , channel in enumerate (self .parent_channels ):
@@ -1576,7 +1576,7 @@ def _reset_no_buffers(
1576
1576
needs_resetting ,
1577
1577
) -> Tuple [TensorDictBase , TensorDictBase ]:
1578
1578
tdunbound = (
1579
- tensordict .unbind (0 )
1579
+ tensordict .consolidate (). unbind (0 )
1580
1580
if is_tensor_collection (tensordict )
1581
1581
else [None ] * self .num_workers
1582
1582
)
@@ -1895,10 +1895,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
1895
1895
i = 0
1896
1896
next_shared_tensordict = shared_tensordict .get ("next" )
1897
1897
root_shared_tensordict = shared_tensordict .exclude ("next" )
1898
- if not (shared_tensordict .is_shared () or shared_tensordict .is_memmap ()):
1899
- raise RuntimeError (
1900
- "tensordict must be placed in shared memory (share_memory_() or memmap_())"
1901
- )
1898
+ # TODO: restore this
1899
+ # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
1900
+ # raise RuntimeError(
1901
+ # "tensordict must be placed in shared memory (share_memory_() or memmap_())"
1902
+ # )
1902
1903
shared_tensordict = shared_tensordict .clone (False ).unlock_ ()
1903
1904
1904
1905
initialized = True
@@ -2130,7 +2131,7 @@ def _run_worker_pipe_direct(
2130
2131
event .record ()
2131
2132
event .synchronize ()
2132
2133
mp_event .set ()
2133
- child_pipe .send (cur_td )
2134
+ child_pipe .send (cur_td . consolidate () )
2134
2135
del cur_td
2135
2136
2136
2137
elif cmd == "step" :
@@ -2142,7 +2143,7 @@ def _run_worker_pipe_direct(
2142
2143
event .record ()
2143
2144
event .synchronize ()
2144
2145
mp_event .set ()
2145
- child_pipe .send (next_td )
2146
+ child_pipe .send (next_td . consolidate () )
2146
2147
del next_td
2147
2148
2148
2149
elif cmd == "step_and_maybe_reset" :
0 commit comments