### Description
In `PrioritizedSampler.sample()`, `_sum_tree.scan_lower_bound()` sometimes returns an index greater than or equal to `len(storage)`. The following test reproduces the bug:
```python
import numpy as np
import torch

from torchrl._torchrl import SumSegmentTreeFp32


def test_sum_tree():
    torch.manual_seed(0)
    np.random.seed(0)
    sum_tree = SumSegmentTreeFp32(500)
    # Repeat to make sure the bug is triggered.
    for _ in range(1000):
        # Update priorities.
        index = torch.arange(0, 100, dtype=torch.long, device=torch.device("cpu"))
        priority = torch.rand(100, device=torch.device("cpu"))
        # The 1e-8 offset is not essential; the bug happens without it too.
        sum_tree[index] = priority + 1e-8
        # Sample.
        p_sum = sum_tree.query(0, 100)
        # Sample a lot of mass values to make the bug likely.
        mass = np.random.uniform(0.0, p_sum, size=1000000)
        scanned_index = sum_tree.scan_lower_bound(mass)
        if scanned_index.max() >= 100:
            print(
                "Unexpected index! p_sum: {}, mass.max(): {}, scanned_index.max(): {}".format(
                    p_sum, mass.max(), scanned_index.max()
                )
            )


test_sum_tree()
```
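
For intuition, the sketch below models `scan_lower_bound` as a lower-bound search over inclusive prefix sums. This is an assumption for illustration only; the real implementation is a C++ segment tree and may differ in detail. Under that model, any mass value greater than or equal to the last prefix sum maps one past the final valid index, and fp32 rounding could plausibly produce such a mass when the total is obtained from a separate `query` call:

```python
# Illustration only: a pure-Python stand-in for scan_lower_bound,
# assuming it behaves like a lower-bound search over prefix sums.
import numpy as np


def scan_lower_bound_reference(priorities: np.ndarray, mass: np.ndarray) -> np.ndarray:
    """Return, for each mass value, the first index whose inclusive prefix sum exceeds it."""
    cumsum = np.cumsum(priorities)
    # side="right": every prefix sum <= mass lies left of the insertion
    # point, so any mass >= cumsum[-1] maps to len(priorities), i.e. one
    # past the last valid index.
    return np.searchsorted(cumsum, mass, side="right")


priorities = np.random.rand(100).astype(np.float32)
total = np.cumsum(priorities)[-1]
# A mass equal to the total scans past the last valid index 99:
print(scan_lower_bound_reference(priorities, np.array([total])))  # -> [100]
```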
In `PrioritizedSampler.sample()`, this out-of-range index is handled by `index.clamp_max_(len(storage) - 1)`, but the clamp still causes unexpected behavior in `PrioritizedSliceSampler.sample()`: although the sample at `index = len(storage) - 1` is covered by `preceding_stop_idx`, it can still be drawn as a start index, leading to cross-trajectory sampling.
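
Below is a hypothetical end-to-end sketch of this failure mode. It assumes, as the issue text implies, that `PrioritizedSliceSampler` excludes `preceding_stop_idx` from sampling by zeroing those priorities; the names and control flow stand in for the actual torchrl code:

```python
# Hypothetical sketch, not the actual torchrl implementation.
import numpy as np

storage_len = 100
priorities = np.random.rand(storage_len).astype(np.float32)

# Indices that must not start a slice (e.g., trajectory boundaries);
# zeroing their priorities should make them unsampleable.
preceding_stop_idx = np.array([97, 98, 99])
priorities[preceding_stop_idx] = 0.0

cumsum = np.cumsum(priorities)
total = cumsum[-1]

# The scan bug: a mass at the top of the range maps past the last index.
buggy_index = int(np.searchsorted(cumsum, total, side="right"))  # -> 100

# The clamp in PrioritizedSampler.sample() maps it back into range.
start_index = min(buggy_index, storage_len - 1)  # -> 99

# Index 99 is drawn as a slice start despite its zero priority, so a
# slice beginning there runs across a trajectory boundary.
print(buggy_index, start_index)  # 100 99
```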
### Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)