## Describe the bug
This issue comes from the original issue #2205. As per the comments there, the `preceding_stop_idx` variable in `PrioritizedSliceSampler.sample()` is meant to build a list of indices that must not be sampled as slice starts: every step within `seq_length` of the end of a trajectory, with the end of the trajectory (`stop_idx`) included. However, it does not build this list correctly.
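
For reference, here is a minimal sketch of the exclusion set described above. The `expected_preceding_stop_idx` helper is hypothetical (it is not torchrl code); in the repro below, `seq_length = batch_size // num_slices = 3`:

```python
import torch


def expected_preceding_stop_idx(trajectory, seq_length):
    # Hypothetical helper (not torchrl code): indices that cannot start a
    # full slice of `seq_length` steps without running past a trajectory end.
    # A trajectory ends wherever the id changes, and at the last stored step.
    changes = (trajectory[1:] != trajectory[:-1]).nonzero().squeeze(-1)
    stop_idx = torch.cat([changes, torch.tensor([len(trajectory) - 1])])
    excluded = set()
    for s in stop_idx.tolist():
        # A slice starting at i covers i .. i + seq_length - 1, so any start
        # in [s - seq_length + 2, s] produces a slice that runs past the
        # trajectory end at s (the stop index s itself is included).
        excluded.update(range(max(0, s - seq_length + 2), s + 1))
    return sorted(excluded)


print(expected_preceding_stop_idx(torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3]), 3))
# -> [0, 1, 2, 4, 5, 7, 8, 9], matching the expected set noted in the repro below
```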
## To Reproduce
The following code demonstrates this issue:
```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)
    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )
    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample(): [5, 4, 8, 7], which should be
        # [5, 4, 8, 7, 9, 0, 1, 2] or [5, 4, 8, 7, 0, 1, 2], depending on
        # whether you want to ignore the spanning trajectories.
        traj = rb.sample()["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))


test_sampler()
```
This causes `PrioritizedSliceSampler.sample()` to return slices that span two trajectories, which is not the expected behavior; `SliceSampler` handles the same situation correctly.
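
To surface the failure programmatically rather than by reading the printout, a check along these lines could replace the print loop in the script above (a sketch that assumes, per `num_slices=2` and `batch_size=6`, that each batch comes back as two contiguous length-3 slices):

```python
# Sketch: every length-3 slice should carry a single trajectory id.
for _ in range(10):
    slices = rb.sample()["trajectory"].reshape(2, 3)  # (num_slices, seq_length)
    for s in slices:
        assert (s == s[0]).all(), f"slice crosses trajectories: {s}"
```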
## Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)