@@ -1068,7 +1068,9 @@ def _get_stop_and_length(self, storage, fallback=True):
                         "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                     )
                 vals = self._find_start_stop_traj(
-                    trajectory=trajectory, at_capacity=storage._is_full
+                    trajectory=trajectory,
+                    at_capacity=storage._is_full,
+                    cursor=getattr(storage, "_last_cursor", None),
                 )
                 if self.cache_values:
                     self._cache["stop-and-length"] = vals
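Note: the new `cursor` keyword tells `_find_start_stop_traj` where the circular storage will write next, so the element just before the cursor can be treated as a trajectory stop even if the episode id does not change there. A toy sketch of that idea, using a hypothetical `find_stop_indices` helper (not the torchrl function) over a 1-D buffer of episode ids:

```python
import torch

def find_stop_indices(trajectory, cursor=None):
    """Toy boundary detection for a circular buffer of episode ids.

    A stop is any index whose episode id differs from its successor's.
    When the buffer has wrapped, the slot just before the write cursor
    is also a stop: the slot after it holds older, unrelated data.
    """
    # Compare each entry with its circular successor.
    succ = torch.roll(trajectory, shifts=-1)
    stop = trajectory != succ
    if cursor is not None:
        # The last-written element cannot continue into the next slot.
        stop[(cursor - 1) % trajectory.numel()] = True
    return stop.nonzero().squeeze(-1)

# Buffer of episode ids [2, 2, 0, 0, 1, 1] with the cursor at index 2:
# index 1 is a stop both because the id changes and because it was
# written last before the wrap point.
print(find_stop_indices(torch.tensor([2, 2, 0, 0, 1, 1]), cursor=2))
# tensor([1, 3, 5])
```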
@@ -1803,7 +1805,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor:  # noqa: F811
             .flip(0)
         )
 
-    def _preceding_stop_idx(self, storage, lengths, seq_length):
+    def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1828,6 +1830,13 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
         # Mask the rightmost values of that padded tensor
         preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
         preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
+        if storage._is_full:
+            preceding_stop_idx = (
+                preceding_stop_idx
+                + np.ravel_multi_index(
+                    tuple(start_idx[0].tolist()), storage._total_shape
+                )
+            ) % storage._total_shape.numel()
         if self.cache_values:
             self._cache["preceding_stop_idx"] = preceding_stop_idx
         return preceding_stop_idx
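Note: when the storage is full, `preceding_stop_idx` is expressed relative to the unrolled view that begins at the first trajectory start, so it has to be rotated back onto physical buffer positions. Adding the flat offset of `start_idx[0]` (via `np.ravel_multi_index`) and taking the result modulo the total number of slots does exactly that. A minimal numpy illustration with made-up shapes:

```python
import numpy as np

total_shape = (4, 3)          # stand-in for storage._total_shape
numel = np.prod(total_shape)  # 12 physical slots

# Indices computed on the "unrolled" view that starts at the oldest
# element, here assumed to be the start coordinate (2, 0).
logical_idx = np.array([0, 5, 9])
offset = np.ravel_multi_index((2, 0), total_shape)  # flat index 6

# Rotate back to physical positions, wrapping past the buffer end.
physical_idx = (logical_idx + offset) % numel
print(physical_idx)  # [ 6 11  3]
```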
@@ -1838,7 +1847,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
 
-        preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
+        preceding_stop_idx = self._preceding_stop_idx(
+            storage, lengths, seq_length, start_idx
+        )
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
@@ -1851,12 +1862,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             )
 
         # force to not sample index at the end of a trajectory
+        vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
         self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
         # and no need to update self._min_tree
 
         starts, info = PrioritizedSampler.sample(
             self, storage=storage, batch_size=batch_size // seq_length
         )
+        self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
         # We must truncate the seq_length if (1) not strict length or (2) span[1]
         if self.span[1] or not self.strict_length:
             if not isinstance(starts, torch.Tensor):
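Note: before this change, the priorities of indices too close to a trajectory end were zeroed permanently in the sum tree; stashing them in `vals` and writing them back after `PrioritizedSampler.sample` makes the masking temporary. The pattern, sketched with a plain array standing in for the sum tree:

```python
import numpy as np

# Toy stand-in for the sum tree: an array of per-index priorities.
priorities = np.array([0.5, 1.0, 2.0, 0.1, 0.7])
forbidden = np.array([1, 3])  # indices too close to a trajectory end

saved = priorities[forbidden].copy()  # remember the masked priorities
priorities[forbidden] = 0.0           # exclude them from this draw

# ... sample start indices proportionally to `priorities` here ...

priorities[forbidden] = saved  # restore, so later samples and priority
                               # updates see the true values again
```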
@@ -1866,7 +1879,13 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             # Find the stop that comes after the start index
             # say start_tensor has shape [N, X] and stop_idx has shape [M, X]
             # diff will have shape [M, N, X]
-            diff = stop_idx.unsqueeze(1) - starts_tensor.unsqueeze(0)
+            stop_idx_corr = stop_idx.clone()
+            stop_idx_corr[:, 0] = torch.where(
+                stop_idx[:, 0] < start_idx[:, 0],
+                stop_idx[:, 0] + storage._len_along_dim0,
+                stop_idx[:, 0],
+            )
+            diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
             # filter out all items that don't belong to the same dim in the storage
             mask = (diff[:, :, 1:] != 0).any(-1)
             diff = diff[:, :, 0]
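Note: with a circular buffer a trajectory can wrap past the end of dim 0, in which case its stop index is numerically smaller than its start. Unwrapping such stops by the buffer length keeps `stop - start` positive, so the argmin below picks the correct stop. A small self-contained example of the correction (the buffer length of 10 is made up):

```python
import torch

len_along_dim0 = 10  # stand-in for storage._len_along_dim0

# A trajectory that wraps: it starts at slot 7 and, after cycling past
# the end of the buffer, stops at slot 2.
start = torch.tensor([7])
stop = torch.tensor([2])

# A stop "before" its start means the episode crossed the buffer end;
# unwrap it so the remaining length comes out positive.
stop_corr = torch.where(stop < start, stop + len_along_dim0, stop)
print(stop_corr - start)  # tensor([5]): 5 steps remain after the start
```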
@@ -1876,7 +1895,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             diff[diff < 0] = diff.max() + 1
             # Take the arg min along dim 0 (thereby reducing dim M)
             idx = diff.argmin(dim=0)
-            stops = stop_idx[idx, 0]
+            stops = stop_idx_corr[idx, 0]
             # TODO: here things may not work bc we could have spanning trajs,
             # though I cannot show that it breaks in the tests
             if starts_tensor.ndim > 1: