[BUG] Unintended Cross-Trajectory Sampling in PrioritizedSliceSampler.sample() #2208
Comments
Should be solved by #2202
Works under #2202, both with strict_length=True and strict_length=False
@vmoens this issue has not been resolved because I made a significant mistake in the test code (my bad)! In the original test code, the sizes of

```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler, PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)
    sampler = PrioritizedSliceSampler(
        max_capacity=10,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    # sampler = SliceSampler(
    #     num_slices=2,
    #     traj_key="trajectory",
    #     strict_length=True,
    #     span=True,
    # )
    trajectory = torch.tensor([3, 0, 1, 1, 1, 2, 2, 2, 3, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(10, device=torch.device("cuda")),
        batch_size=6,
    )
    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample(): [1 2 3 5 6 8 9]
        traj = rb.sample()["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))


test_sampler()
```

With the new test code, two issues arise here (tested on torchrl-nightly==2024.6.13):
#2228 will fix this, but be mindful that I don't think that in this example we can really sample traj 3.

```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler, PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)
    sampler = PrioritizedSliceSampler(
        max_capacity=10,
        num_slices=2,
        traj_key="trajectory",
        # end_key="done",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory0 = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
    done0 = torch.tensor([False, True, False, False, True, False, False, True, False, False])
    td0 = TensorDict({"trajectory": trajectory0, "steps": torch.arange(10), "done": done0}, [10])
    trajectory1 = torch.tensor([3])
    done1 = torch.tensor([False])
    td1 = TensorDict({"trajectory": trajectory1, "steps": torch.tensor([10]), "done": done1}, [1])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(10, device=torch.device("cpu")),
        batch_size=6,
    )
    rb.extend(td0)
    rb.extend(td1)
    for i in range(10):
        # preceding_stop_idx in sample(): [1 2 3 5 6 8 9]
        s, info = rb.sample(return_info=True)
        traj = s["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))
        print("[loop {}] index {}".format(i, info["index"]))
        assert len(traj.unique()) <= 2


test_sampler()
```
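One way to see why trajectory 3 becomes hard to sample here: the second `extend` pushes the buffer past its capacity of 10, so the oldest entry (index 0) is overwritten and trajectory 3 ends up wrapping around the circular storage. A minimal sketch of the assumed FIFO-overwrite behavior (plain Python, not torchrl internals):

```python
# Sketch of how a capacity-10 circular storage evolves under the two
# extends above, assuming the oldest entries are overwritten first.
capacity = 10
storage = [None] * capacity
cursor = 0
# First extend writes trajectory0, second extend appends one more step of traj 3.
for value in [0, 0, 1, 1, 1, 2, 2, 2, 3, 3] + [3]:
    storage[cursor % capacity] = value
    cursor += 1
print(storage)  # [3, 0, 1, 1, 1, 2, 2, 2, 3, 3]
```

After the wrap, trajectory 3 occupies indices 8, 9, and 0, so a contiguous length-3 slice of it cannot be drawn without spanning the buffer boundary. Note this is the same trajectory layout used in the first snippet.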
Oh I'm okay with that. I think missing a single trajectory has little impact on most RL algorithms.
Describe the bug
This issue comes from the original issue #2205.
As per the comments, the `preceding_stop_idx` variable in `PrioritizedSliceSampler.sample()` attempts to build a list of indexes that we don't want to sample: all the steps within a seq_length distance from the end of the trajectory, with the end of the trajectory (stop_idx) included.
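To make the intended exclusion set concrete, here is a hedged sketch in plain Python (a hypothetical helper, not torchrl code; in the snippets above seq_length = batch_size // num_slices = 3):

```python
def invalid_slice_starts(trajectory, seq_length):
    """Indexes from which a slice of seq_length steps would cross a
    trajectory boundary. Illustrative only, not the torchrl implementation."""
    n = len(trajectory)
    # Last index of each contiguous trajectory run (its "stop_idx").
    stops = [i for i in range(n - 1) if trajectory[i] != trajectory[i + 1]]
    stops.append(n - 1)
    excluded = set()
    for s in stops:
        # A slice starting at i stays inside its trajectory only if
        # i + seq_length - 1 <= s, so starts in (s - seq_length + 1, s]
        # must be excluded; the stop index itself is always excluded.
        excluded.update(range(max(s - seq_length + 2, 0), s + 1))
    return sorted(excluded)

print(invalid_slice_starts([0, 0, 1, 1, 1, 2, 2, 2, 3, 3], seq_length=3))
# [0, 1, 3, 4, 6, 7, 8, 9]
```

Under this reading, the valid slice starts for the second example's storage would be indices 2 and 5, one per full-length trajectory.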
However, it does not do this correctly.

To Reproduce
The following code demonstrates this issue:
This causes `PrioritizedSliceSampler.sample()` to sample across trajectories, which is not the expected behavior, unlike `SliceSampler`, which handles this correctly.

Checklist