From b0ea50d547ef54b736bdf4202a6618e62afea754 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Jun 2024 08:57:55 +0100
Subject: [PATCH 1/4] init

---
 torchrl/data/replay_buffers/samplers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index b23650d5e52..822fbc59bc9 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1068,7 +1068,9 @@ def _get_stop_and_length(self, storage, fallback=True):
                     "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                 )
             vals = self._find_start_stop_traj(
-                trajectory=trajectory, at_capacity=storage._is_full
+                trajectory=trajectory,
+                at_capacity=storage._is_full,
+                cursor=getattr(storage, "_last_cursor", None),
             )
             if self.cache_values:
                 self._cache["stop-and-length"] = vals

From 15f1f0d489af84c0a2b186c24d62e80fe2fed6d1 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Jun 2024 09:16:45 +0100
Subject: [PATCH 2/4] amend

---
 torchrl/data/replay_buffers/samplers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 822fbc59bc9..bcacbbcbbc7 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1806,6 +1806,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
         )
 
     def _preceding_stop_idx(self, storage, lengths, seq_length):
+        print('lengths', lengths)
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1841,6 +1842,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
 
         preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
+        preceding_stop_idx = (preceding_stop_idx + start_idx[0, 0]) % storage._len_along_dim0
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
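The two patches above pass the storage write cursor down to the trajectory start/stop computation, so that boundaries are found on the time-ordered view of a full circular buffer rather than on its physical layout. Below is a minimal, self-contained sketch of that idea; the `unroll` helper is hypothetical and not the torchrl implementation.

# Minimal sketch (not torchrl code) of why the write cursor matters when a
# circular buffer is full: trajectory ids must be unrolled starting at the
# cursor, otherwise the oldest and newest trajectories appear merged across
# the wrap point.
import torch

def unroll(trajectory, cursor):
    # put the oldest element first so equal consecutive ids are consecutive in time
    return torch.cat([trajectory[cursor:], trajectory[:cursor]])

trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])  # physical layout, cursor at 2
unrolled = unroll(trajectory, cursor=2)  # -> [0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
# positions (in the unrolled view) where a trajectory ends, except the last one
stops = (unrolled[1:] != unrolled[:-1]).nonzero().squeeze(-1)
print(stops.tolist())  # [0, 3, 6]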
From 7c0eaa40887a31d04148f7e19b11819e2d605a5a Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Jun 2024 10:27:00 +0100
Subject: [PATCH 3/4] amend

---
 test/test_rb.py                         | 17 +++++++++--------
 torchrl/data/replay_buffers/samplers.py | 25 +++++++++++++++++++------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index eb7c80cf0e6..6313390ffbe 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -2230,12 +2230,12 @@ def test_slice_sampler(
     def test_slice_sampler_at_capacity(self, sampler):
         torch.manual_seed(0)
 
-        trajectory0 = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
-        trajectory1 = torch.arange(2).repeat_interleave(5)
+        trajectory0 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
+        trajectory1 = torch.arange(2).repeat_interleave(6)
         trajectory = torch.stack([trajectory0, trajectory1], 0)
 
         td = TensorDict(
-            {"trajectory": trajectory, "steps": torch.arange(10).expand(2, 10)}, [2, 10]
+            {"trajectory": trajectory, "steps": torch.arange(12).expand(2, 12)}, [2, 12]
         )
 
         rb = ReplayBuffer(
@@ -2469,7 +2469,8 @@ def test_slice_sampler_strictlength(self):
     @pytest.mark.parametrize("ndim", [1, 2])
     @pytest.mark.parametrize("strict_length", [True, False])
     @pytest.mark.parametrize("circ", [False, True])
-    def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
+    @pytest.mark.parametrize("at_capacity", [False, True])
+    def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity):
         torch.manual_seed(0)
         out = []
         for t in range(5):
@@ -2491,9 +2492,9 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
         if ndim == 2:
             data = torch.stack([data, data])
         rb = TensorDictReplayBuffer(
-            storage=LazyTensorStorage(data.numel(), ndim=ndim),
+            storage=LazyTensorStorage(data.numel() - at_capacity, ndim=ndim),
             sampler=PrioritizedSliceSampler(
-                max_capacity=data.numel(),
+                max_capacity=data.numel() - at_capacity,
                 alpha=1.0,
                 beta=1.0,
                 end_key="done",
@@ -2530,8 +2531,8 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
             assert (samples["traj"] == 0).any()
             # Check that all samples of the first traj contain all elements (since it's too short to fullfill 10 elts)
             sc = samples[samples["traj"] == 0]["step_count"]
-            assert (sc == 0).sum() == (sc == 1).sum()
-            assert (sc == 0).sum() == (sc == 4).sum()
+            assert (sc == 1).sum() == (sc == 2).sum()
+            assert (sc == 1).sum() == (sc == 4).sum()
         assert rb._sampler._cache
         rb.extend(data)
         assert not rb._sampler._cache
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index bcacbbcbbc7..d9b05051e27 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1805,8 +1805,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
             .flip(0)
         )
 
-    def _preceding_stop_idx(self, storage, lengths, seq_length):
-        print('lengths', lengths)
+    def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1831,6 +1830,13 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
         # Mask the rightmost values of that padded tensor
         preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
         preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
+        if storage._is_full:
+            preceding_stop_idx = (
+                preceding_stop_idx
+                + np.ravel_multi_index(
+                    tuple(start_idx[0].tolist()), storage._total_shape
+                )
+            ) % storage._total_shape.numel()
         if self.cache_values:
             self._cache["preceding_stop_idx"] = preceding_stop_idx
         return preceding_stop_idx
@@ -1841,8 +1847,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
 
-        preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
-        preceding_stop_idx = (preceding_stop_idx + start_idx[0, 0]) % storage._len_along_dim0
+        preceding_stop_idx = self._preceding_stop_idx(
+            storage, lengths, seq_length, start_idx
+        )
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
@@ -1870,7 +1877,13 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             # Find the stop that comes after the start index
            # say start_tensor has shape [N, X] and stop_idx has shape [M, X]
             # diff will have shape [M, N, X]
-            diff = stop_idx.unsqueeze(1) - starts_tensor.unsqueeze(0)
+            stop_idx_corr = stop_idx.clone()
+            stop_idx_corr[:, 0] = torch.where(
+                stop_idx[:, 0] < start_idx[:, 0],
+                stop_idx[:, 0] + storage._len_along_dim0,
+                stop_idx[:, 0],
+            )
+            diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
             # filter out all items that don't belong to the same dim in the storage
             mask = (diff[:, :, 1:] != 0).any(-1)
             diff = diff[:, :, 0]
@@ -1880,7 +1893,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             diff[diff < 0] = diff.max() + 1
             # Take the arg min along dim 0 (thereby reducing dim M)
             idx = diff.argmin(dim=0)
-            stops = stop_idx[idx, 0]
+            stops = stop_idx_corr[idx, 0]
             # TODO: here things may not work bc we could have spanning trajs,
             # though I cannot show that it breaks in the tests
             if starts_tensor.ndim > 1:
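In the hunks above, when the storage is full the stop indices are expressed in the unrolled, cursor-first view of the buffer, so they are shifted by the flat position of the first valid element (via np.ravel_multi_index for multi-dimensional storages) and wrapped modulo the total number of slots before they can index the physical storage. Below is a minimal one-dimensional sketch of that mapping with made-up values; it is not the torchrl code.

# Minimal sketch (plain numpy) of mapping indices from the time-ordered view of
# a full circular buffer back to physical slots: shift by the slot of the oldest
# element, then wrap modulo the capacity.
import numpy as np

capacity = 10
first_physical_slot = 4                      # where the oldest element currently lives
ends_in_unrolled_view = np.array([2, 5, 9])  # trajectory ends found on the unrolled view
physical = (ends_in_unrolled_view + first_physical_slot) % capacity
print(physical)                              # [6 9 3]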
From 3b32f20401ba27d2ed09d91ac9a2ce90b91abcb4 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Jun 2024 10:39:59 +0100
Subject: [PATCH 4/4] amend

---
 torchrl/data/replay_buffers/samplers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index d9b05051e27..1d1499312b8 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1862,12 +1862,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             )
 
         # force to not sample index at the end of a trajectory
+        vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
         self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
         # and no need to update self._min_tree
 
         starts, info = PrioritizedSampler.sample(
            self, storage=storage, batch_size=batch_size // seq_length
         )
+        self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
         # We must truncate the seq_length if (1) not strict length or (2) span[1]
         if self.span[1] or not self.strict_length:
            if not isinstance(starts, torch.Tensor):
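The hunk above saves the priorities of trajectory-ending indices, zeroes them only for the duration of the prioritized draw, and writes them back afterwards so that later priority updates still see the true values. Below is a minimal sketch of that save/zero/restore pattern with a plain numpy array standing in for the sum tree; the names and values are illustrative, not torchrl's.

# Minimal sketch: indices that must not be sampled get priority 0 for one draw,
# then their original priorities are restored.
import numpy as np

rng = np.random.default_rng(0)
priorities = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64)
forbidden = np.array([1, 3])                 # e.g. indices at the end of a trajectory

saved = priorities[forbidden].copy()         # remember the original priorities
priorities[forbidden] = 0.0                  # mask them out for this sample call
probs = priorities / priorities.sum()
sample = rng.choice(len(priorities), size=2, p=probs)
priorities[forbidden] = saved                # restore so later updates see true values

print(sample, priorities)                    # forbidden indices never appear in `sample`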