
[Feature Request] A Method to Modify ReplayBuffer In Place #2209

Closed

Description

@wertyuilife2

Motivation

This issue comes from the original issue #2205.

My work requires modifying the contents of the replay buffer. Specifically, I need to sample an item, modify it, and write it back into the buffer. However, torchrl currently does not seem to encourage this: calling buffer._storage.set(index, data) to write the modified data back implicitly changes _storage._len, which can later cause the sampler to return uninitialized (empty) samples. The following code demonstrates the issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = SliceSampler(
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
    )
    trajectory = torch.tensor([4, 4, 1, 2, 2, 2, 3, 3, 3, 4])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20),
        batch_size=6,
    )
    rb.extend(td)

    for i in range(10):
        data, info = rb.sample(return_info=True)
        print("[loop {}]sampled trajectory: {}".format(i, data["trajectory"]))

        # I want to modify data and put it back
        # data = modify_func(data)
        rb._storage.set(info["index"], data)

        # len(storage) increases as a side effect of rb._storage.set(),
        # causing undefined data (trajectory 0) to be sampled in later loops.
        print("[loop {}] len(storage): {}".format(i, len(rb._storage)))

test_sampler()

I worked around this by directly modifying buffer._storage._storage while holding buffer._replay_lock. It took me two days to discover that TensorStorage.set() implicitly changes _len; I believe this method should behave more intuitively. I am not sure whether other Storage classes have the same issue, but TensorStorage definitely does.
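
For reference, the workaround looks roughly like this (it relies on the private attributes _replay_lock and _storage._storage, so it may break between torchrl versions):

# Workaround sketch: write the modified data directly into the
# underlying TensorDict, bypassing Storage.set() so that _len is
# left untouched. Both _replay_lock and _storage._storage are
# private torchrl internals.
with rb._replay_lock:
    rb._storage._storage[info["index"]] = data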

Solution

Provide a method that can modify the ReplayBuffer in place, e.g. ReplayBuffer.set(index, data).
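
A minimal sketch of what this could look like, written as a subclass purely for illustration (the name InPlaceReplayBuffer, and the _len save/restore, are my assumptions about one possible implementation, not the actual torchrl API):

from torchrl.data.replay_buffers import ReplayBuffer

# Illustrative only: the proposed method could live on ReplayBuffer itself.
class InPlaceReplayBuffer(ReplayBuffer):
    def set(self, index, data):
        # Overwrite the entries at `index` without moving the write cursor:
        # save the current length, delegate to the storage, then restore
        # the length that Storage.set() mutated as a side effect.
        with self._replay_lock:
            old_len = len(self._storage)
            self._storage.set(index, data)
            self._storage._len = old_len

With such a method, the loop above could simply call rb.set(info["index"], data) instead of touching rb._storage directly.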

Additional context

See discussion in the original issue #2205.

Checklist

  • I have checked that there is no similar issue in the repo (required)
