Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 7, 2024
1 parent 332499a commit fae8ce7
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
19 changes: 19 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,25 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
else:
assert found_traj_0

def test_prb_update_max_priority(self):
    """Max-priority bookkeeping: survives adds, slot overwrites, and extends.

    Writes 20 items into a 10-slot buffer so every slot is overwritten once;
    the cached max priority must be dropped (and re-learned) whenever the
    slot holding it is overwritten.
    """
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
    )
    for value in torch.arange(20):
        index = buffer.add(value)
        buffer.update_priority(index, 21 - value)
        # First pass (values 0..9): priority 21 written at slot 0 stays max.
        # Second pass (values 10..19): slot 0 is overwritten at value 10,
        # so the max resets to 21 - 10 = 11 — again at slot 0.
        expected_max = 21 if value <= 9 else 11
        assert buffer._sampler._max_priority[0] == expected_max
        assert buffer._sampler._max_priority[1] == 0
    indices = buffer.extend(torch.arange(10))
    buffer.update_priority(indices, 12)
    assert buffer._sampler._max_priority[0] == 12
    assert buffer._sampler._max_priority[1] == 0


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
Expand Down
61 changes: 56 additions & 5 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,60 @@ def _init(self):
raise NotImplementedError(
f"dtype {self.dtype} not supported by PrioritizedSampler"
)
self._max_priority = 1.0
self._max_priority = None

def _empty(self):
    """Reset the sampler by re-running ``_init`` (which, among other
    things, resets ``_max_priority`` to ``None`` — see ``_init``)."""
    self._init()

@property
def _max_priority(self):
max_priority_index = self.__dict__.get("_max_priority")
if max_priority_index is None:
return (None, None)
return max_priority_index

@_max_priority.setter
def _max_priority(self, value):
    # Write straight into the instance dict: the property getter reads it
    # back via __dict__.get, so a plain attribute assignment would recurse
    # into this data descriptor instead of storing the value.
    self.__dict__["_max_priority"] = value

def _maybe_erase_max_priority(self, index):
max_priority_index = self._max_priority[1]
if max_priority_index is None:
return

def check_index(index=index, max_priority_index=max_priority_index):
if isinstance(index, torch.Tensor):
# index can be 1d or 2d
if index.ndim == 1:
is_overwritten = (index == max_priority_index).any()
else:
is_overwritten = (index == max_priority_index).all(-1).any()
elif isinstance(index, int):
is_overwritten = index == max_priority_index
elif isinstance(index, slice):
# This won't work if called recursively
is_overwritten = max_priority_index in range(
index.indices(self._max_capacity)
)
elif isinstance(index, tuple):
is_overwritten = isinstance(max_priority_index, tuple)
if is_overwritten:
for idx, mpi in zip(index, max_priority_index):
is_overwritten &= check_index(idx, mpi)
else:
raise TypeError(f"index of type {type(index)} is not recognized.")
return is_overwritten

is_overwritten = check_index()
if is_overwritten:
self._max_priority = None

@property
def default_priority(self) -> float:
    """Priority assigned to items that have no recorded priority yet.

    Computed as ``(max_priority + eps) ** alpha``, falling back to a
    maximum of 1 when no priority has been recorded (cached entry is
    ``(None, None)``).
    """
    # The stale pre-change line that returned (self._max_priority + eps)
    # directly has been removed: _max_priority is now a (priority, index)
    # pair, and adding a float to it would raise TypeError.
    mp = self._max_priority[0]
    if mp is None:
        mp = 1
    return (mp + self._eps) ** self._alpha

def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
if len(storage) == 0:
Expand Down Expand Up @@ -422,11 +468,13 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:

return index, {"_weight": weight}

def add(self, index: torch.Tensor | int) -> None:
    """Register a newly written storage index.

    Besides the base-class bookkeeping, drops the cached max-priority
    entry when the written slot is the one that held the maximum.
    """
    # The leftover pre-change signature line (index: int) has been removed;
    # two consecutive `def` lines are a syntax error.
    super().add(index)
    self._maybe_erase_max_priority(index)

def extend(self, index: torch.Tensor | tuple) -> None:
    """Register a batch of newly written storage indices.

    Besides the base-class bookkeeping, drops the cached max-priority
    entry when any written slot is the one that held the maximum.
    """
    # The leftover pre-change signature line (index: torch.Tensor) has been
    # removed; two consecutive `def` lines are a syntax error.
    super().extend(index)
    self._maybe_erase_max_priority(index)

@torch.no_grad()
def update_priority(
Expand Down Expand Up @@ -494,7 +542,10 @@ def update_priority(
if priority.ndim:
priority = priority[valid_index]

self._max_priority = priority.max().clamp_min(self._max_priority).item()
max_p, max_p_idx = priority.max(dim=0)
max_priority = self._max_priority[0]
if max_priority is None or max_p > max_priority:
self._max_priority = (max_p, max_p_idx)
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority
Expand Down

0 comments on commit fae8ce7

Please sign in to comment.