[Doc] Fix formatting errors #2786

Open
wants to merge 1 commit into base: gh/vmoens/90/base
2 changes: 1 addition & 1 deletion torchrl/data/datasets/atari_dqn.py
@@ -60,7 +60,7 @@ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
     root (Path or str, optional): The AtariDQN dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.atari`.
+        ``~/.cache/torchrl/atari``.
     num_procs (int, optional): number of processes to launch for preprocessing.
         Has no effect whenever the data is already downloaded. Defaults to 0
        (no multiprocessing used).
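
For context on the line being fixed: the root argument controls where the memory-mapped files land. A minimal sketch of that behavior, assuming the public AtariDQNExperienceReplay API and the "Pong/5" dataset id used elsewhere in the torchrl docs:

    # Sketch of the `root` default described in the docstring above.
    # Note: instantiating the dataset triggers a download on first use.
    from torchrl.data.datasets import AtariDQNExperienceReplay

    # No root given: files are saved under ~/.cache/torchrl/atari/<dataset_id>
    dataset = AtariDQNExperienceReplay("Pong/5")

    # Explicit root: files are saved under /tmp/atari_data/<dataset_id> instead
    dataset = AtariDQNExperienceReplay("Pong/5", root="/tmp/atari_data")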
2 changes: 1 addition & 1 deletion torchrl/data/datasets/d4rl.py
@@ -106,7 +106,7 @@ class D4RLExperienceReplay(BaseDatasetExperienceReplay):
     root (Path or str, optional): The D4RL dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.d4rl`.
+        ``~/.cache/torchrl/d4rl``.
     download (bool, optional): Whether the dataset should be downloaded if
         not found. Defaults to ``True``.
     **env_kwargs (key-value pairs): additional kwargs for
2 changes: 1 addition & 1 deletion torchrl/data/datasets/gen_dgrl.py
@@ -60,7 +60,7 @@ class GenDGRLExperienceReplay(BaseDatasetExperienceReplay):
     dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.gen_dgrl`.
+        ``~/.cache/torchrl/gen_dgrl``.
     download (bool or str, optional): Whether the dataset should be downloaded if
         not found. Defaults to ``True``. Download can also be passed as ``"force"``,
         in which case the downloaded data will be overwritten.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/minari_data.py
@@ -66,7 +66,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
     root (Path or str, optional): The Minari dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.minari`.
+        ``~/.cache/torchrl/minari``.
     download (bool or str, optional): Whether the dataset should be downloaded if
         not found. Defaults to ``True``. Download can also be passed as ``"force"``,
         in which case the downloaded data will be overwritten.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/openx.py
@@ -123,7 +123,7 @@ class for more information on how to interact with non-tensor data
     root (Path or str, optional): The OpenX dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.openx`.
+        ``~/.cache/torchrl/openx``.
     streaming (bool, optional): if ``True``, the data won't be downloaded but
         read from a stream instead.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/roboset.py
@@ -57,7 +57,7 @@ class RobosetExperienceReplay(BaseDatasetExperienceReplay):
     root (Path or str, optional): The Roboset dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.roboset`.
+        ``~/.cache/torchrl/roboset``.
     download (bool or str, optional): Whether the dataset should be downloaded if
         not found. Defaults to ``True``. Download can also be passed as ``"force"``,
         in which case the downloaded data will be overwritten.
2 changes: 1 addition & 1 deletion torchrl/data/datasets/vd4rl.py
@@ -63,7 +63,7 @@ class VD4RLExperienceReplay(BaseDatasetExperienceReplay):
     root (Path or str, optional): The V-D4RL dataset root directory.
         The actual dataset memory-mapped files will be saved under
         `<root>/<dataset_id>`. If none is provided, it defaults to
-        `~/.cache/torchrl/atari`.vd4rl`.
+        ``~/.cache/torchrl/vd4rl``.
     download (bool or str, optional): Whether the dataset should be downloaded if
         not found. Defaults to ``True``. Download can also be passed as ``"force"``,
         in which case the downloaded data will be overwritten.
5 changes: 3 additions & 2 deletions torchrl/data/map/query.py
@@ -80,12 +80,12 @@ class QueryModule(TensorDictModuleBase):
         If a single ``hash_module`` is provided but no aggregator is passed, it will take
         the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is
         provided but no aggregator is passed, it will default to ``SipHash``.
-    clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
+    clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
         returned. This can be used to retrieve the integer index within the storage,
         corresponding to a given input tensordict. This can be overridden at runtime by
         providing the ``clone`` argument to the forward method.
         Defaults to ``False``.
-    d
+
     Examples:
         >>> query_module = QueryModule(
         ...     in_keys=["key1", "key2"],
@@ -106,6 +106,7 @@ class QueryModule(TensorDictModuleBase):
         >>> # The last three pairs of key1 and key2 have at least one mismatching value
         >>> assert res["index"][1] != res["index"][2]
         >>> assert res["index"][2] != res["index"][3]
+
     """

     def __init__(
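
Since the clone argument is the subject of the fixed paragraph, here is a minimal sketch of how it behaves, assuming the constructor arguments shown in the docstring's Examples block (SipHash fills in as the default aggregator) and the call-time override the docstring mentions:

    import torch
    from tensordict import TensorDict
    from torchrl.data.map import QueryModule

    # Construction mirrors the docstring's Examples block.
    query_module = QueryModule(
        in_keys=["key1", "key2"],
        index_key="index",
    )

    td = TensorDict(
        {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])},
        batch_size=[2],
    )

    # Default (clone=False): the "index" entry is written into the input.
    res = query_module(td)
    assert "index" in res.keys()

    # Override at call time, as the docstring describes: a shallow clone
    # of the input carries the "index" entry instead of the input itself.
    res = query_module(td, clone=True)
    assert res is not td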
16 changes: 14 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -99,6 +99,7 @@ class ReplayBuffer:
         is used with PyTree structures (see example below).
     batch_size (int, optional): the batch size to be used when sample() is
         called.
+
         .. note::
             The batch-size can be specified at construction time via the
@@ -108,6 +109,7 @@ class ReplayBuffer:
             incompatible with prefetching (since this requires to know the
             batch-size in advance) as well as with samplers that have a
             ``drop_last`` argument.
+
     dim_extend (int, optional): indicates the dim to consider for
         extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
         When using ``dim_extend > 0``, we recommend using the ``ndim``
@@ -128,6 +130,7 @@ class ReplayBuffer:
         >>> for d in data.unbind(1):
         ...     rb.add(d)
         >>> rb.extend(data)
+
     generator (torch.Generator, optional): a generator to use for sampling.
         Using a dedicated generator for the replay buffer can allow a fine-grained control
         over seeding, for instance keeping the global seed different but the RB seed identical
@@ -582,6 +585,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]):

     .. note:: Hooks are currently not serialized when saving a replay buffer: they must
         be manually re-initialized every time the buffer is created.
+
     """
     self._storage.register_save_hook(hook)

@@ -926,15 +930,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         construct a tensordict from the non-tensordict content.
     batch_size (int, optional): the batch size to be used when sample() is
         called.
-        .. note::
-            The batch-size can be specified at construction time via the
+
+        .. note:: The batch-size can be specified at construction time via the
             ``batch_size`` argument, or at sampling time. The former should
             be preferred whenever the batch-size is consistent across the
             experiment. If the batch-size is likely to change, it can be
             passed to the :meth:`sample` method. This option is
             incompatible with prefetching (since this requires to know the
             batch-size in advance) as well as with samplers that have a
             ``drop_last`` argument.
+
     dim_extend (int, optional): indicates the dim to consider for
         extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
         When using ``dim_extend > 0``, we recommend using the ``ndim``
@@ -1051,6 +1056,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
         construct a tensordict from the non-tensordict content.
     batch_size (int, optional): the batch size to be used when sample() is
         called.
+
         .. note::
             The batch-size can be specified at construction time via the
@@ -1060,6 +1066,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
             incompatible with prefetching (since this requires to know the
             batch-size in advance) as well as with samplers that have a
             ``drop_last`` argument.
+
     priority_key (str, optional): the key at which priority is assumed to
         be stored within TensorDicts added to this ReplayBuffer.
         This is to be used when the sampler is of type
@@ -1085,6 +1092,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
         >>> for d in data.unbind(1):
         ...     rb.add(d)
         >>> rb.extend(data)
+
     generator (torch.Generator, optional): a generator to use for sampling.
         Using a dedicated generator for the replay buffer can allow a fine-grained control
         over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1394,6 +1402,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
         construct a tensordict from the non-tensordict content.
     batch_size (int, optional): the batch size to be used when sample() is
         called.
+
         .. note::
             The batch-size can be specified at construction time via the
@@ -1403,6 +1412,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
             incompatible with prefetching (since this requires to know the
             batch-size in advance) as well as with samplers that have a
             ``drop_last`` argument.
+
     priority_key (str, optional): the key at which priority is assumed to
         be stored within TensorDicts added to this ReplayBuffer.
         This is to be used when the sampler is of type
@@ -1431,6 +1441,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
         >>> for d in data.unbind(1):
         ...     rb.add(d)
         >>> rb.extend(data)
+
     generator (torch.Generator, optional): a generator to use for sampling.
         Using a dedicated generator for the replay buffer can allow a fine-grained control
         over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1669,6 +1680,7 @@ class ReplayBufferEnsemble(ReplayBuffer):
         Defaults to ``None`` (global default generator).

         .. warning:: As of now, the generator has no effect on the transforms.
+
     shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
         Defaults to ``False``.
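
The note that keeps recurring in these docstrings is easier to read as code. A minimal sketch, assuming ReplayBuffer and LazyTensorStorage from torchrl.data (the generator argument is the one documented above):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    # Preferred when the batch size is constant: fix it at construction.
    rb = ReplayBuffer(
        storage=LazyTensorStorage(100),
        batch_size=32,
        generator=torch.Generator().manual_seed(0),  # buffer-local seeding
    )
    rb.extend(torch.arange(80))
    sample = rb.sample()  # no argument needed
    assert sample.shape == torch.Size([32])

    # When the batch size varies, pass it at sampling time instead.
    # (This is incompatible with prefetching and with samplers that
    # expose a ``drop_last`` argument, as the note above says.)
    rb2 = ReplayBuffer(storage=LazyTensorStorage(100))
    rb2.extend(torch.arange(80))
    sample = rb2.sample(batch_size=16)
    assert sample.shape == torch.Size([16])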
1 change: 1 addition & 0 deletions torchrl/data/rlhf/utils.py
@@ -198,6 +198,7 @@ class RolloutFromModel:
         batch_size=torch.Size([4, 50]),
         device=cpu,
         is_shared=False)
+
     """

     EOS_TOKEN_ID = 50256