
[BUG] Unintended Cross-Trajectory Sampling in PrioritizedSliceSampler.sample() #2208

@wertyuilife2

Describe the bug

This issue comes from the original issue #2205.

According to the code comments, the preceding_stop_idx variable in PrioritizedSliceSampler.sample() is meant to hold the indices that must not be sampled: all the steps too close to the end of a trajectory to start a full slice of seq_length steps, with the end of the trajectory (stop_idx) included. However, the list it builds is incomplete.
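To make the expected exclusion set concrete, here is a minimal pure-Python sketch that enumerates the start indices from which a slice of seq_length steps would leave its trajectory. It does not use torchrl internals: the helper name invalid_slice_starts and the allow_wraparound flag are hypothetical, and it assumes seq_length = batch_size // num_slices (3 for the reproducer in this report) and that a "spanning" trajectory is one that wraps from the last stored index back to index 0.

```python
def invalid_slice_starts(trajectory, seq_length, allow_wraparound=False):
    """Return the start indices whose slice would cross a trajectory boundary.

    trajectory: list of trajectory ids, one per stored step.
    allow_wraparound: if True, a slice may continue from the last stored
        index to index 0 (hypothetical flag modelling spanning trajectories).
    """
    n = len(trajectory)
    invalid = []
    for start in range(n):
        idxs = [start + k for k in range(seq_length)]
        if allow_wraparound:
            idxs = [i % n for i in idxs]
        elif idxs[-1] >= n:
            # The slice would run past the stored data.
            invalid.append(start)
            continue
        # A start is invalid if any step of its slice has another trajectory id.
        if any(trajectory[i] != trajectory[start] for i in idxs):
            invalid.append(start)
    return invalid


# Trajectory layout used in this report, with seq_length = 6 // 2 = 3.
trajectory = [3, 3, 0, 1, 1, 1, 2, 2, 2, 3]
print(invalid_slice_starts(trajectory, 3))
# [0, 1, 2, 4, 5, 7, 8, 9]
print(invalid_slice_starts(trajectory, 3, allow_wraparound=True))
# [0, 1, 2, 4, 5, 7, 8]
```

Up to ordering, these are the two candidate lists given in the reproducer's comment. Under the assumptions above, the only valid starts are 3 and 6, plus 9 when the spanning trajectory is allowed.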

To Reproduce

The following code demonstrates this issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
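        # Fall back to CPU so the reproducer also runs without a GPU.
        storage=LazyTensorStorage(20, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")),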
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample() is [5, 4, 8, 7], but it should be
        # [5, 4, 8, 7, 9, 0, 1, 2] or [5, 4, 8, 7, 0, 1, 2], depending on
        # whether you want to ignore the spanning trajectories.
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))

test_sampler()

As a result, PrioritizedSliceSampler.sample() can return slices that cross trajectory boundaries, which is not the expected behavior. SliceSampler handles this case correctly.
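To detect the problem automatically rather than by reading the printed output, a check along the following lines could be applied to each sampled batch. This is a sketch only: slices_cross_trajectories is a hypothetical helper, and it assumes the batch is laid out as num_slices contiguous slices of equal length, which is what strict_length=True with batch_size=6 and num_slices=2 is expected to give.

```python
def slices_cross_trajectories(traj_ids, num_slices):
    """Return one flag per slice: True if the slice mixes trajectory ids.

    traj_ids: flat list of trajectory ids from one sampled batch,
        assumed to hold num_slices equal-length slices back to back.
    """
    seq_length = len(traj_ids) // num_slices
    rows = [
        traj_ids[i * seq_length:(i + 1) * seq_length]
        for i in range(num_slices)
    ]
    # A slice is clean only if all of its steps share one trajectory id.
    return [len(set(row)) > 1 for row in rows]


# Two clean slices, one from trajectory 1 and one from trajectory 2.
print(slices_cross_trajectories([1, 1, 1, 2, 2, 2], num_slices=2))
# [False, False]
# Both slices cross a trajectory boundary.
print(slices_cross_trajectories([1, 1, 2, 2, 2, 3], num_slices=2))
# [True, True]
```

In the reproducer, this could be called as slices_cross_trajectories(rb.sample()["trajectory"].tolist(), num_slices=2). Any True entry would indicate a cross-trajectory slice.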

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
