[Doc] Fix formatting errors
ghstack-source-id: ac1f3da66c1374d3d19fed88e80f8ed5407b3459
Pull Request resolved: #2786

(cherry picked from commit 03d6586)
vmoens committed Feb 18, 2025
1 parent 882dc79 commit e42ffd3
Showing 10 changed files with 45 additions and 31 deletions.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/atari_dqn.py
@@ -60,7 +60,7 @@ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
root (Path or str, optional): The AtariDQN dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/atari`.
``~/.cache/torchrl/atari``.
num_procs (int, optional): number of processes to launch for preprocessing.
Has no effect whenever the data is already downloaded. Defaults to 0
(no multiprocessing used).
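For context, a minimal usage sketch of the class whose docstring is being fixed; the dataset id (``Pong/5``) is an illustrative assumption, while ``root`` and ``num_procs`` are the parameters documented above:

>>> from torchrl.data.datasets import AtariDQNExperienceReplay
>>> dataset = AtariDQNExperienceReplay(
...     "Pong/5",                       # assumed dataset_id, adjust to your setup
...     root="~/.cache/torchrl/atari",  # the documented default location
...     num_procs=4,                    # parallel preprocessing workers
... )
>>> sample = dataset.sample(32)         # returns a batch of transitions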
2 changes: 1 addition & 1 deletion torchrl/data/datasets/d4rl.py
@@ -106,7 +106,7 @@ class D4RLExperienceReplay(BaseDatasetExperienceReplay):
root (Path or str, optional): The D4RL dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/d4rl`.
``~/.cache/torchrl/d4rl``.
download (bool, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``.
**env_kwargs (key-value pairs): additional kwargs for
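As above, a hedged sketch of how the ``root`` argument is typically used; the dataset id and path below are assumptions, not part of the diff:

>>> from torchrl.data.datasets import D4RLExperienceReplay
>>> data = D4RLExperienceReplay(
...     "halfcheetah-medium-v2",   # illustrative D4RL dataset id
...     batch_size=128,
...     root="~/data/d4rl",        # overrides the ~/.cache/torchrl/d4rl default
... )
>>> batch = data.sample()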
2 changes: 1 addition & 1 deletion torchrl/data/datasets/gen_dgrl.py
@@ -60,7 +60,7 @@ class GenDGRLExperienceReplay(BaseDatasetExperienceReplay):
dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/gen_dgrl`.
``~/.cache/torchrl/gen_dgrl``.
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/minari_data.py
@@ -66,7 +66,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
root (Path or str, optional): The Minari dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/minari`.
``~/.cache/torchrl/minari``.
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
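A brief sketch of the ``download="force"`` behaviour documented above; the Minari dataset id is an assumed placeholder:

>>> from torchrl.data.datasets import MinariExperienceReplay
>>> replay = MinariExperienceReplay(
...     "D4RL/pen/human-v2",   # assumed Minari dataset id
...     batch_size=32,
...     download="force",      # re-download and overwrite any local copy
... )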
2 changes: 1 addition & 1 deletion torchrl/data/datasets/openx.py
@@ -123,7 +123,7 @@ class for more information on how to interact with non-tensor data
root (Path or str, optional): The OpenX dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/openx`.
``~/.cache/torchrl/openx``.
streaming (bool, optional): if ``True``, the data won't be downloaded but
read from a stream instead.
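A hedged sketch of the ``streaming`` option described above; ``cmu_stretch`` is an assumed dataset id:

>>> from torchrl.data.datasets import OpenXExperienceReplay
>>> dataset = OpenXExperienceReplay(
...     "cmu_stretch",    # assumed OpenX dataset id
...     batch_size=16,
...     streaming=True,   # read from a stream instead of downloading
... )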
2 changes: 1 addition & 1 deletion torchrl/data/datasets/roboset.py
@@ -57,7 +57,7 @@ class RobosetExperienceReplay(BaseDatasetExperienceReplay):
root (Path or str, optional): The Roboset dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/roboset`.
``~/.cache/torchrl/roboset``.
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/vd4rl.py
@@ -63,7 +63,7 @@ class VD4RLExperienceReplay(BaseDatasetExperienceReplay):
root (Path or str, optional): The V-D4RL dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/vd4rl`.
``~/.cache/torchrl/vd4rl``.
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
45 changes: 23 additions & 22 deletions torchrl/data/map/query.py
@@ -80,32 +80,33 @@ class QueryModule(TensorDictModuleBase):
If a single ``hash_module`` is provided but no aggregator is passed, it will take
the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is
provided but no aggregator is passed, it will default to ``SipHash``.
clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
returned. This can be used to retrieve the integer index within the storage
corresponding to a given input tensordict. This can be overridden at runtime by
providing the ``clone`` argument to the forward method.
Defaults to ``False``.
Examples:
>>> query_module = QueryModule(
... in_keys=["key1", "key2"],
... index_key="index",
... hash_module=SipHash(),
... )
>>> query = TensorDict(
... {
... "key1": torch.Tensor([[1], [1], [1], [2]]),
... "key2": torch.Tensor([[3], [3], [2], [3]]),
... "other": torch.randn(4),
... },
... batch_size=(4,),
... )
>>> res = query_module(query)
>>> # The first two pairs of key1 and key2 match
>>> assert res["index"][0] == res["index"][1]
>>> # The last three pairs of key1 and key2 have at least one mismatching value
>>> assert res["index"][1] != res["index"][2]
>>> assert res["index"][2] != res["index"][3]
"""

def __init__(
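Building on the docstring example above, a small sketch of the runtime ``clone`` override; import paths are assumed from the file layout shown in this diff:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.map import QueryModule, SipHash
>>> query_module = QueryModule(
...     in_keys=["key1", "key2"], index_key="index", hash_module=SipHash()
... )
>>> td = TensorDict(
...     {"key1": torch.Tensor([[1], [2]]), "key2": torch.Tensor([[3], [3]])},
...     batch_size=(2,),
... )
>>> out = query_module(td, clone=True)  # operates on a shallow clone of td
>>> index = out["index"]                # integer index within the storage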
16 changes: 14 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -99,6 +99,7 @@ class ReplayBuffer:
is used with PyTree structures (see example below).
batch_size (int, optional): the batch size to be used when sample() is
called.
.. note::
The batch-size can be specified at construction time via the
``batch_size`` argument, or at sampling time. The former should
be preferred whenever the batch-size is consistent across the
experiment. If the batch-size is likely to change, it can be
passed to the :meth:`sample` method. This option is
incompatible with prefetching (since this requires to know the
batch-size in advance) as well as with samplers that have a
``drop_last`` argument.
dim_extend (int, optional): indicates the dim to consider for
extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
When using ``dim_extend > 0``, we recommend using the ``ndim``
@@ -128,6 +130,7 @@ class ReplayBuffer:
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
generator (torch.Generator, optional): a generator to use for sampling.
Using a dedicated generator for the replay buffer can allow a fine-grained control
over seeding, for instance keeping the global seed different but the RB seed identical
@@ -582,6 +585,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]):
.. note:: Hooks are currently not serialized when saving a replay buffer: they must
be manually re-initialized every time the buffer is created.
"""
self._storage.register_save_hook(hook)
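A hedged sketch of the hook mechanism, assuming tensordict-shaped data; the cast performed by the hook is purely illustrative:

>>> from torchrl.data import ReplayBuffer, LazyMemmapStorage
>>> def to_half(data):
...     # illustrative hook: cast floating-point tensors before writing to disk
...     return data.apply(lambda t: t.half() if t.is_floating_point() else t)
...
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(1000))
>>> rb.register_save_hook(to_half)
>>> # hooks are not serialized: re-register them on every newly created buffer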

@@ -926,15 +930,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
construct a tensordict from the non-tensordict content.
batch_size (int, optional): the batch size to be used when sample() is
called.
.. note::
The batch-size can be specified at construction time via the
``batch_size`` argument, or at sampling time. The former should
be preferred whenever the batch-size is consistent across the
experiment. If the batch-size is likely to change, it can be
passed to the :meth:`sample` method. This option is
incompatible with prefetching (since this requires to know the
batch-size in advance) as well as with samplers that have a
``drop_last`` argument.
dim_extend (int, optional): indicates the dim to consider for
extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
When using ``dim_extend > 0``, we recommend using the ``ndim``
@@ -1051,6 +1056,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
construct a tensordict from the non-tensordict content.
batch_size (int, optional): the batch size to be used when sample() is
called.
.. note::
The batch-size can be specified at construction time via the
``batch_size`` argument, or at sampling time. The former should
be preferred whenever the batch-size is consistent across the
experiment. If the batch-size is likely to change, it can be
passed to the :meth:`sample` method. This option is
incompatible with prefetching (since this requires to know the
batch-size in advance) as well as with samplers that have a
``drop_last`` argument.
priority_key (str, optional): the key at which priority is assumed to
be stored within TensorDicts added to this ReplayBuffer.
This is to be used when the sampler is of type
@@ -1085,6 +1092,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
generator (torch.Generator, optional): a generator to use for sampling.
Using a dedicated generator for the replay buffer can allow a fine-grained control
over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1394,6 +1402,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
construct a tensordict from the non-tensordict content.
batch_size (int, optional): the batch size to be used when sample() is
called.
.. note::
The batch-size can be specified at construction time via the
``batch_size`` argument, or at sampling time. The former should
be preferred whenever the batch-size is consistent across the
experiment. If the batch-size is likely to change, it can be
passed to the :meth:`sample` method. This option is
incompatible with prefetching (since this requires to know the
batch-size in advance) as well as with samplers that have a
``drop_last`` argument.
priority_key (str, optional): the key at which priority is assumed to
be stored within TensorDicts added to this ReplayBuffer.
This is to be used when the sampler is of type
@@ -1431,6 +1441,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
generator (torch.Generator, optional): a generator to use for sampling.
Using a dedicated generator for the replay buffer can allow a fine-grained control
over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1669,6 +1680,7 @@ class ReplayBufferEnsemble(ReplayBuffer):
Defaults to ``None`` (global default generator).
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
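To make the recurring ``batch_size`` and ``generator`` notes concrete, a minimal sketch; storage size and seed are arbitrary choices:

>>> import torch
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(1000),
...     batch_size=32,   # fixed here: required for prefetching / drop_last samplers
...     generator=torch.Generator().manual_seed(0),  # RB-local sampling seed
... )
>>> rb.extend(torch.randn(100, 4))
>>> batch = rb.sample()    # uses the construction-time batch_size (32)
>>> small = rb.sample(8)   # or override per call when it must vary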
1 change: 1 addition & 0 deletions torchrl/data/rlhf/utils.py
@@ -198,6 +198,7 @@ class RolloutFromModel:
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False)
"""

EOS_TOKEN_ID = 50256
