
[BUG] Unexpected behavior of SumSegmentTree Resulting in Invalid Slices in PrioritizedSliceSampler.sample() #2230

Closed
@wertyuilife2

Description

In PrioritizedSampler.sample(), _sum_tree.scan_lower_bound() sometimes returns an index greater than or equal to len(storage).

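The following script reproduces the problem. Only the first 100 leaves of a tree with capacity 500 are filled, so any returned index >= 100 is out of range: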

```python
import torch
from torchrl._torchrl import SumSegmentTreeFp32
import numpy as np


def test_sum_tree():
    torch.manual_seed(0)
    np.random.seed(0)
    sum_tree = SumSegmentTreeFp32(500)

    # Repeat many times so the bug is very likely to show up.
    for _ in range(1000):
        # Update the priorities of the first 100 leaves (the filled part of storage).
        index = torch.arange(0, 100, dtype=torch.long, device=torch.device("cpu"))
        priority = torch.rand(100, device=torch.device("cpu"))
        sum_tree[index] = priority + 1e-8  # the 1e-8 offset is not essential; the bug also occurs without it

        # Sample: draw masses uniformly below p_sum and map them back to leaf indices.
        p_sum = sum_tree.query(0, 100)
        mass = np.random.uniform(0.0, p_sum, size=1000000)  # draw many samples so the bug is likely to trigger
        scanned_index = sum_tree.scan_lower_bound(mass)
        # Every valid index is < 100; anything >= 100 points past the filled leaves.
        if scanned_index.max() >= 100:
            print(
                "Unexpected index! p_sum:{}, mass.max():{}, scanned_index.max():{}".format(
                    p_sum, mass.max(), scanned_index.max()
                )
            )


test_sum_tree()
```
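The report does not say why the scan overshoots. One plausible cause is float32 rounding, though this is an assumption and is not confirmed here. `query(0, 100)` may combine the leaves in a different order than the root node does, so `p_sum` can come out a hair larger than the stored total. The subtractions made while descending the tree are also rounded.

The pure-Python sketch below shows how a lower-bound scan behaves once a mass exceeds the stored total. It is a hypothetical float32 sum tree, not torchrl's C++ code. `SumTreeF32`, `f32`, and its `scan_lower_bound` are illustrative names, and fp32 storage is emulated with `struct`.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float to the nearest float32 value (emulates fp32 storage)."""
    return struct.unpack("f", struct.pack("f", x))[0]


class SumTreeF32:
    """Hypothetical array-backed sum tree whose nodes are all stored as float32."""

    def __init__(self, size: int):
        self.size = size
        self.capacity = 1
        while self.capacity < size:
            self.capacity *= 2
        # Node 1 is the root; leaves live at positions [capacity, 2 * capacity).
        self.values = [0.0] * (2 * self.capacity)

    def __setitem__(self, index: int, priority: float) -> None:
        node = index + self.capacity
        self.values[node] = f32(priority)
        node //= 2
        while node >= 1:
            # Each parent is the float32-rounded sum of its two children.
            self.values[node] = f32(self.values[2 * node] + self.values[2 * node + 1])
            node //= 2

    def total(self) -> float:
        return self.values[1]

    def scan_lower_bound(self, mass: float) -> int:
        """Return the leftmost leaf whose running prefix sum reaches `mass`."""
        if mass > self.values[1]:
            # No leaf can satisfy this; return an out-of-range sentinel.
            return self.size
        node, remaining = 1, mass
        while node < self.capacity:
            left = 2 * node
            if remaining > self.values[left]:
                # Skip the whole left subtree; this subtraction is rounded too.
                remaining = f32(remaining - self.values[left])
                node = left + 1
            else:
                node = left
        return node - self.capacity


# Only the first 3 of 8 leaves are filled, like len(storage) = 3.
tree = SumTreeF32(8)
for i, p in enumerate([0.1, 0.2, 0.3]):
    tree[i] = p

print(tree.scan_lower_bound(0.25))  # lands on leaf 1
overshoot = math.nextafter(tree.total(), math.inf)  # one float step above the stored total
print(tree.scan_lower_bound(overshoot))  # out of range: >= 3
```

In this sketch, a mass above the root returns `size` as a sentinel. torchrl's behavior in that case may differ, but the result is the same kind of index past the filled leaves that the repro script detects.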

PrioritizedSampler.sample() guards against this by clamping the result with index.clamp_max_(len(storage) - 1), but PrioritizedSliceSampler.sample() still misbehaves.

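This happens because index len(storage) - 1 is listed in preceding_stop_idx, which marks positions that must not be used as the start of a slice. The clamp can still turn an out-of-range index into exactly that position. The sampler then uses it as a start index, and the resulting slice crosses into another trajectory.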
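The failure can be illustrated without torchrl. The sketch below uses hypothetical helpers (`valid_slice_starts`, `clamp_to_storage`, `snap_to_valid_start`) on a toy storage of two 5-step trajectories with slice length 3. It treats storage as non-circular, and the invalid starts play the role of `preceding_stop_idx`.

Clamping an overshooting draw to `len(storage) - 1` lands on position 9, which is not a valid start. One possible alternative is to snap each clamped draw to the largest valid start at or below it. This is an assumption for illustration, not the fix adopted upstream.

```python
import bisect


def valid_slice_starts(traj_ids, slice_len):
    """Start positions whose slice stays inside storage and inside one trajectory."""
    n = len(traj_ids)
    return [
        s
        for s in range(n)
        if s + slice_len <= n and len(set(traj_ids[s : s + slice_len])) == 1
    ]


def clamp_to_storage(indices, storage_len):
    """Pure-Python stand-in for index.clamp_max_(len(storage) - 1)."""
    return [min(i, storage_len - 1) for i in indices]


def snap_to_valid_start(indices, storage_len, valid):
    """Hypothetical alternative: clamp, then move to the largest valid start <= index."""
    if not valid:
        raise ValueError("no valid slice start in storage")
    snapped = []
    for i in clamp_to_storage(indices, storage_len):
        pos = bisect.bisect_right(valid, i) - 1  # `valid` must be sorted ascending
        snapped.append(valid[max(pos, 0)])  # fall back to the first valid start
    return snapped


# Two trajectories of 5 steps each; slices are 3 steps long.
traj_ids = [0] * 5 + [1] * 5
valid = valid_slice_starts(traj_ids, slice_len=3)  # [0, 1, 2, 5, 6, 7]
draws = [1, 6, 12]  # 12 mimics an out-of-range scan_lower_bound result

clamped = clamp_to_storage(draws, len(traj_ids))  # [1, 6, 9]: 9 is not a valid start
snapped = snap_to_valid_start(draws, len(traj_ids), valid)  # [1, 6, 7]
print(clamped, snapped)
```

Snapping leftwards keeps each draw close to where the scan pointed. However, it shifts a little probability mass onto the snapped-to start. Resampling the offending draws would avoid that bias at some extra cost.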

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug (Something isn't working)