
Commit 35df59e

Author: Vincent Moens
[BugFix] Fix sliced PRB when only traj is provided (#2228)
1 parent ce92e35 commit 35df59e
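
The diff below suggests the following picture: when a PrioritizedSliceSampler is driven only by a trajectory id (traj_key) and the storage has wrapped around (is at capacity), the indices derived from the time-ordered trajectory column no longer coincide with the physical slots of the circular storage, so end-of-trajectory priorities were zeroed at the wrong positions and a trajectory's stop index could land "before" its start. The patch threads the write cursor into _find_start_stop_traj, shifts the forbidden indices by the cursor offset modulo the storage size, saves and restores the zeroed priorities, and unwraps stop indices for trajectories that straddle the wrap point. Below is a minimal sketch (not part of the patch) of the at-capacity, trajectory-only setup the fix targets, assuming the public torchrl API (traj_key, slice_len and strict_length arguments of PrioritizedSliceSampler):

# Hedged sketch: exercises the "only traj is provided" + at-capacity path.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

torch.manual_seed(0)

# Four trajectories of length 3, identified only by a "trajectory" id.
trajectory = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
td = TensorDict(
    {"trajectory": trajectory, "steps": torch.arange(12)},
    batch_size=[12],
)

rb = TensorDictReplayBuffer(
    # capacity smaller than the data, so the buffer wraps and is at capacity
    storage=LazyTensorStorage(10),
    sampler=PrioritizedSliceSampler(
        max_capacity=10,
        alpha=1.0,
        beta=1.0,
        traj_key="trajectory",   # no end/done signal, only the traj id
        slice_len=3,
        strict_length=False,     # the wrapped buffer truncates one trajectory
    ),
    batch_size=6,
)
rb.extend(td)         # the write cursor has wrapped past slot 0
sample = rb.sample()  # previously mis-handled when storage._is_full
print(sample["trajectory"])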

File tree

2 files changed (+33 −13 lines)

test/test_rb.py

Lines changed: 9 additions & 8 deletions
@@ -2230,12 +2230,12 @@ def test_slice_sampler(
     def test_slice_sampler_at_capacity(self, sampler):
         torch.manual_seed(0)

-        trajectory0 = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
-        trajectory1 = torch.arange(2).repeat_interleave(5)
+        trajectory0 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
+        trajectory1 = torch.arange(2).repeat_interleave(6)
         trajectory = torch.stack([trajectory0, trajectory1], 0)

         td = TensorDict(
-            {"trajectory": trajectory, "steps": torch.arange(10).expand(2, 10)}, [2, 10]
+            {"trajectory": trajectory, "steps": torch.arange(12).expand(2, 12)}, [2, 12]
         )

         rb = ReplayBuffer(
@@ -2469,7 +2469,8 @@ def test_slice_sampler_strictlength(self):
     @pytest.mark.parametrize("ndim", [1, 2])
     @pytest.mark.parametrize("strict_length", [True, False])
     @pytest.mark.parametrize("circ", [False, True])
-    def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
+    @pytest.mark.parametrize("at_capacity", [False, True])
+    def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity):
         torch.manual_seed(0)
         out = []
         for t in range(5):
@@ -2491,9 +2492,9 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
         if ndim == 2:
             data = torch.stack([data, data])
         rb = TensorDictReplayBuffer(
-            storage=LazyTensorStorage(data.numel(), ndim=ndim),
+            storage=LazyTensorStorage(data.numel() - at_capacity, ndim=ndim),
             sampler=PrioritizedSliceSampler(
-                max_capacity=data.numel(),
+                max_capacity=data.numel() - at_capacity,
                 alpha=1.0,
                 beta=1.0,
                 end_key="done",
@@ -2530,8 +2531,8 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
             assert (samples["traj"] == 0).any()
             # Check that all samples of the first traj contain all elements (since it's too short to fullfill 10 elts)
             sc = samples[samples["traj"] == 0]["step_count"]
-            assert (sc == 0).sum() == (sc == 1).sum()
-            assert (sc == 0).sum() == (sc == 4).sum()
+            assert (sc == 1).sum() == (sc == 2).sum()
+            assert (sc == 1).sum() == (sc == 4).sum()
         assert rb._sampler._cache
         rb.extend(data)
         assert not rb._sampler._cache
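
The assertion change in the last hunk follows from the new at_capacity case: with a storage of data.numel() - 1 elements, the write cursor wraps by exactly one slot, so the oldest transition (step_count 0 of the first trajectory) is overwritten and the surviving samples of that trajectory start at step_count 1; the rewritten assertions hold whether or not the buffer wrapped. A tiny illustration of that one-slot eviction, assuming the standard wrap-around behaviour of ReplayBuffer/LazyTensorStorage:

# One-slot eviction in a circular storage: the oldest entry is the one lost.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(4))
data = TensorDict({"step_count": torch.arange(5)}, batch_size=[5])
rb.extend(data)             # 5 items into 4 slots: step_count 0 is overwritten
print(rb[:]["step_count"])  # tensor([4, 1, 2, 3])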

torchrl/data/replay_buffers/samplers.py

Lines changed: 24 additions & 5 deletions
@@ -1068,7 +1068,9 @@ def _get_stop_and_length(self, storage, fallback=True):
                     "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                 )
             vals = self._find_start_stop_traj(
-                trajectory=trajectory, at_capacity=storage._is_full
+                trajectory=trajectory,
+                at_capacity=storage._is_full,
+                cursor=getattr(storage, "_last_cursor", None),
             )
             if self.cache_values:
                 self._cache["stop-and-length"] = vals
@@ -1803,7 +1805,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor:  # noqa: F811
             .flip(0)
         )

-    def _preceding_stop_idx(self, storage, lengths, seq_length):
+    def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1828,6 +1830,13 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
         # Mask the rightmost values of that padded tensor
         preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
         preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
+        if storage._is_full:
+            preceding_stop_idx = (
+                preceding_stop_idx
+                + np.ravel_multi_index(
+                    tuple(start_idx[0].tolist()), storage._total_shape
+                )
+            ) % storage._total_shape.numel()
         if self.cache_values:
             self._cache["preceding_stop_idx"] = preceding_stop_idx
         return preceding_stop_idx
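
When the storage is full, preceding_stop_idx is built in the time-ordered view, whose index 0 corresponds to the physical slot currently under the write cursor (the first start index). Adding that slot's flat offset, obtained with np.ravel_multi_index over the storage shape, and wrapping modulo the total number of slots converts the forbidden positions into physical indices into the sum tree. A toy numpy illustration of the conversion (illustrative values only):

# Map time-ordered indices to physical slots in a full (2, 6) storage
# (2 environments x 6 steps), given the write cursor.
import numpy as np

total_shape = (2, 6)
cursor = (0, 2)                                           # oldest element's slot
flat_offset = np.ravel_multi_index(cursor, total_shape)   # -> 2

time_ordered_idx = np.array([3, 4, 10])                   # forbidden, time order
physical_idx = (time_ordered_idx + flat_offset) % np.prod(total_shape)
print(physical_idx)                                       # [5 6 0]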
@@ -1838,7 +1847,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         seq_length, num_slices = self._adjusted_batch_size(batch_size)

-        preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
+        preceding_stop_idx = self._preceding_stop_idx(
+            storage, lengths, seq_length, start_idx
+        )
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
@@ -1851,12 +1862,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             )

         # force to not sample index at the end of a trajectory
+        vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
         self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
         # and no need to update self._min_tree

         starts, info = PrioritizedSampler.sample(
             self, storage=storage, batch_size=batch_size // seq_length
         )
+        self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
         # We must truncate the seq_length if (1) not strict length or (2) span[1]
         if self.span[1] or not self.strict_length:
             if not isinstance(starts, torch.Tensor):
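
Zeroing the priorities of end-of-trajectory positions keeps them out of the prioritized draw, but previously the zeroed entries were never written back, so once an index had been masked it kept a zero priority in every later sample. Caching the values before masking and restoring them after PrioritizedSampler.sample keeps the exclusion local to a single call. A minimal sketch of the save-mask-restore pattern over a plain priority vector (not the actual sum-tree type):

# Save-mask-restore over a toy priority vector; the patch does the same on
# self._sum_tree at the preceding_stop_idx positions.
import numpy as np

priorities = np.array([0.5, 1.0, 0.2, 0.8, 0.1])
forbidden = np.array([1, 3])                    # e.g. ends of trajectories

saved = priorities[forbidden].copy()            # remember the masked priorities
priorities[forbidden] = 0.0                     # exclude them from this draw
picked = np.random.choice(len(priorities), size=2,
                          p=priorities / priorities.sum())
priorities[forbidden] = saved                   # later draws see them again
print(picked, priorities)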
@@ -1866,7 +1879,13 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
                 # Find the stop that comes after the start index
                 # say start_tensor has shape [N, X] and stop_idx has shape [M, X]
                 # diff will have shape [M, N, X]
-                diff = stop_idx.unsqueeze(1) - starts_tensor.unsqueeze(0)
+                stop_idx_corr = stop_idx.clone()
+                stop_idx_corr[:, 0] = torch.where(
+                    stop_idx[:, 0] < start_idx[:, 0],
+                    stop_idx[:, 0] + storage._len_along_dim0,
+                    stop_idx[:, 0],
+                )
+                diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
                 # filter out all items that don't belong to the same dim in the storage
                 mask = (diff[:, :, 1:] != 0).any(-1)
                 diff = diff[:, :, 0]
@@ -1876,7 +1895,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
                 diff[diff < 0] = diff.max() + 1
                 # Take the arg min along dim 0 (thereby reducing dim M)
                 idx = diff.argmin(dim=0)
-                stops = stop_idx[idx, 0]
+                stops = stop_idx_corr[idx, 0]
                 # TODO: here things may not work bc we could have spanning trajs,
                 # though I cannot show that it breaks in the tests
                 if starts_tensor.ndim > 1:
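
These last two hunks handle trajectories that straddle the physical end of the storage: their stop index can be numerically smaller than their start index, which would break the "first stop after this start" search (it relies on non-negative differences and an argmin). Adding storage._len_along_dim0 to such stops unwraps them onto a virtual axis, and the corrected tensor is used both for the differences and for reading off the chosen stop. A small torch sketch of the unwrap-and-argmin idea (illustrative indices, not the sampler's internals):

# For each sampled start, find the first trajectory stop at or after it,
# when one trajectory wraps around a storage of length 10.
import torch

length = 10
starts = torch.tensor([3, 9])          # physical start slots of sampled slices
start_idx = torch.tensor([2, 6, 9])    # trajectory starts (the last one wraps ...)
stop_idx = torch.tensor([5, 8, 1])     # ... and ends at slot 1

# Unwrap stops that sit "before" their own start onto a virtual axis.
stop_corr = torch.where(stop_idx < start_idx, stop_idx + length, stop_idx)

diff = stop_corr.unsqueeze(1) - starts.unsqueeze(0)   # [num_trajs, num_starts]
diff[diff < 0] = diff.max() + 1                       # drop stops before the start
idx = diff.argmin(dim=0)                              # first stop >= each start
print(stop_corr[idx] % length)                        # tensor([5, 1])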
