Description
Motivation
This issue comes from the original issue #2205.
My work requires modifying the contents of the buffer: I need to sample an item, modify it, and put it back. However, TorchRL currently does not seem to encourage modifying buffer contents. Calling buffer._storage.set(index, data) to write the modified data back implicitly changes _storage._len, which can cause the sampler to return empty samples. The following code demonstrates the issue:
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)

    sampler = SliceSampler(
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
    )
    trajectory = torch.tensor([4, 4, 1, 2, 2, 2, 3, 3, 3, 4])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20),
        batch_size=6,
    )

    rb.extend(td)

    for i in range(10):
        data, info = rb.sample(return_info=True)
        print("[loop {}]sampled trajectory: {}".format(i, data["trajectory"]))

        # I want to modify data and put it back
        # data = modify_func(data)
        rb._storage.set(info["index"], data)

        # The len(storage) increases due to rb._storage.set(),
        # causing sampling of undefined data (trajectory 0) in the future loop.
        print("[loop {}]len(storage): {}".format(i, len(rb._storage)))


test_sampler()
I resolved this by directly modifying buffer._storage._storage while holding the buffer._replay_lock. It took me two days to discover that TensorStorage.set() implicitly changes _len. I believe this method should behave more intuitively. I am not sure whether other Storage classes have similar issues, but TensorStorage definitely does.
Solution
Provide a method that modifies the ReplayBuffer in place, such as ReplayBuffer.set(index, data).
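To make the requested semantics concrete, here is a minimal toy sketch. It is not TorchRL code: the ToyBuffer class, its list-based storage, and the get() helper are hypothetical stand-ins. The only point it illustrates is that an in-place set() should accept only indices that are already filled and should never change the buffer length.

```python
import threading


class ToyBuffer:
    """Hypothetical stand-in for a replay buffer; not TorchRL's implementation."""

    def __init__(self, capacity):
        self._data = [None] * capacity  # pre-allocated slots
        self._len = 0                   # number of slots holding valid items
        self._lock = threading.RLock()  # guards concurrent reads and writes

    def __len__(self):
        return self._len

    def extend(self, items):
        """Append new items; the only operation allowed to grow the buffer."""
        with self._lock:
            if self._len + len(items) > len(self._data):
                raise ValueError("buffer capacity exceeded")
            for item in items:
                self._data[self._len] = item
                self._len += 1

    def get(self, i):
        """Read one already-written item."""
        with self._lock:
            if not 0 <= i < self._len:
                raise IndexError("index out of range")
            return self._data[i]

    def set(self, index, items):
        """Overwrite existing items in place; len(self) is left unchanged."""
        with self._lock:
            if len(index) != len(items):
                raise ValueError("index and items must have the same length")
            # Validate every index before writing, so a bad call changes nothing.
            if any(not 0 <= i < self._len for i in index):
                raise IndexError("set() may only overwrite already-filled slots")
            for i, item in zip(index, items):
                self._data[i] = item


buf = ToyBuffer(capacity=20)
buf.extend(list(range(10)))
buf.set([2, 5], [-2, -5])  # modify two sampled items and put them back
print(len(buf), buf.get(2), buf.get(5))
```

In this sketch an out-of-range index raises an error instead of silently extending the valid region, so a sampler that relies on len(buf) can never see slots that were not written through extend().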
Additional context
See discussion in the original issue #2205.
Checklist
- I have checked that there is no similar issue in the repo (required)