Skip to content

Commit

Permalink
amend
Browse files (browse the repository at this point in the history)
  • Loading branch information
vmoens committed Jun 24, 2024
1 parent 38943aa commit 5fffe1e
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,9 +1392,7 @@ def _step_and_maybe_reset_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:

td = tensordict.consolidate(
share_memory=True, inplace=True, num_threads=1
)
td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
for i in range(td.shape[0]):
# We send the same td multiple times as it is in shared mem and we just need to index it
# in each process.
Expand Down Expand Up @@ -1496,12 +1494,11 @@ def step_and_maybe_reset(
def _step_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
for i, data in enumerate(
tensordict.consolidate(
share_memory=True, inplace=True, num_threads=1
).unbind(0)
):
self.parent_channels[i].send(("step", data))
data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
for i, local_data in enumerate(data.unbind(0)):
self.parent_channels[i].send(("step", local_data))
# for i in range(data.shape[0]):
# self.parent_channels[i].send(("step", (data, i)))
out_tds = []
for i, channel in enumerate(self.parent_channels):
self._events[i].wait()
Expand Down Expand Up @@ -1586,17 +1583,24 @@ def _reset_no_buffers(
reset_kwargs_list,
needs_resetting,
) -> Tuple[TensorDictBase, TensorDictBase]:
tdunbound = (
tensordict.consolidate(share_memory=True, num_threads=1).unbind(0)
if is_tensor_collection(tensordict)
else [None] * self.num_workers
)
if is_tensor_collection(tensordict):
# tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
tensordict = tensordict.consolidate(
share_memory=True, num_threads=1
).unbind(0)
else:
tensordict = [None] * self.num_workers
out_tds = [None] * self.num_workers
for i, (data, reset_kwargs) in enumerate(zip(tdunbound, reset_kwargs_list)):
for i, (local_data, reset_kwargs) in enumerate(
zip(tensordict, reset_kwargs_list)
):
if not needs_resetting[i]:
out_tds[i] = tdunbound[i].exclude(*self.reset_keys)
localtd = local_data
if localtd is not None:
localtd = localtd.exclude(*self.reset_keys)
out_tds[i] = localtd
continue
self.parent_channels[i].send(("reset", (data, reset_kwargs)))
self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))

for i, channel in enumerate(self.parent_channels):
if not needs_resetting[i]:
Expand Down Expand Up @@ -2129,6 +2133,8 @@ def _run_worker_pipe_direct(
raise RuntimeError("call 'init' before resetting")
# we use 'data' to pass the keys that we need to pass to reset,
# because passing the entire buffer may have unwanted consequences
# data, idx, reset_kwargs = data
# data = data[idx]
data, reset_kwargs = data
if data is not None:
data._fast_apply(
Expand All @@ -2151,6 +2157,8 @@ def _run_worker_pipe_direct(
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
# data, idx = data
# data = data[idx]
next_td = env._step(data)
if event is not None:
event.record()
Expand All @@ -2165,8 +2173,8 @@ def _run_worker_pipe_direct(
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
data, idx = data
data = data[idx]
# data, idx = data
# data = data[idx]
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
Expand Down

0 comments on commit 5fffe1e

Please sign in to comment.