[Doc] Solve ref issues in docstrings
ghstack-source-id: 09823fa85a94115291e7434478776fb0834f9b39
Pull Request resolved: #2776

(cherry picked from commit f5445a4)
vmoens committed Feb 12, 2025
1 parent dfde953 commit 6fea6c4
Showing 63 changed files with 415 additions and 274 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
@@ -26,10 +26,11 @@ jobs:
build-docs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.4"]
python_version: [ "3.9" ]
cuda_arch_version: [ "12.4" ]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
repository: pytorch/rl
upload-artifact: docs
timeout: 120
@@ -38,7 +39,6 @@ jobs:
set -v
# apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
yum makecache
# yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut
# Install Mesa and OpenGL Libraries:
yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL
# Install DRI Drivers:
@@ -112,7 +112,7 @@ jobs:
cd ./docs
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build -v -j 4
cd ..
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
@@ -123,8 +123,8 @@ jobs:
upload:
needs: build-docs
if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
permissions:
contents: write
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -28,3 +28,4 @@ vmas
onnxscript
onnxruntime
onnx
psutil
4 changes: 4 additions & 0 deletions docs/reset.py
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
24 changes: 20 additions & 4 deletions docs/source/conf.py
@@ -28,8 +28,7 @@
import pytorch_sphinx_theme
import torchrl

# Suppress warnings - TODO
# suppress_warnings = [ 'misc.highlighting_failure' ]
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

project = "torchrl"
@@ -86,6 +85,21 @@
"torchvision": ("https://pytorch.org/vision/stable/", None),
}


def kill_procs(gallery_conf, fname):
import os

import psutil

# Get the current process
current_proc = psutil.Process(os.getpid())
# Iterate over all child processes
for child in current_proc.children(recursive=True):
# Kill the child process
child.terminate()
print(f"Killed child process with PID {child.pid}") # noqa: T201


sphinx_gallery_conf = {
"examples_dirs": "reference/generated/tutorials/", # path to your example scripts
"gallery_dirs": "tutorials", # path to where to save gallery generated output
@@ -95,9 +109,11 @@
"notebook_images": "reference/generated/tutorials/media/", # images to parse
"download_all_examples": True,
"abort_on_example_error": True,
"show_memory": True,
"capture_repr": ("_repr_html_", "__repr__"), # capture representations
# "show_memory": True,
# "capture_repr": ("_repr_html_", "__repr__"), # capture representations
"write_computation_times": True,
# "compress_images": ("images", "thumbnails"),
"reset_modules": (kill_procs, "matplotlib", "seaborn"),
}

napoleon_use_ivar = True
1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -104,7 +104,6 @@ Intermediate
tutorials/pretrained_models
tutorials/dqn_with_rnn
tutorials/rb_tutorial
tutorials/export

Advanced
--------
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
@@ -212,6 +212,7 @@ component (sub-environments or agents) should be reset.
This makes it possible to reset some but not all of the components.

The ``"_reset"`` key has two distinct functionalities:

1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
not be present in the input tensordict. TorchRL's convention is that the
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
@@ -899,7 +900,7 @@ to be able to create this other composition:
Hash
InitTracker
KLRewardTransform
LineariseReward
LineariseRewards
NoopResetEnv
ObservationNorm
ObservationTransform
4 changes: 2 additions & 2 deletions torchrl/_utils.py
@@ -513,7 +513,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None):
"""Resets the setters in setter_dict.
``setter_dict`` is a copy of implementations. We just need to iterate through its
values and call :meth:`~.module_set` for each.
values and call :meth:`module_set` for each.
"""
if VERBOSE:
@@ -888,7 +888,7 @@ def _standardize(
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.
"""
if eps is None:
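For illustration, the role of ``eps`` documented in the ``_standardize`` docstring above can be sketched in plain Python (a simplified stand-in for the torch implementation; ``standardize`` is a hypothetical helper, and ``1e-6`` matches float32 resolution):

```python
import math


def standardize(values, mean=None, std=None, eps=None):
    # eps guards the division when std is (near-)zero;
    # it defaults to float32 resolution, as the docstring states
    if eps is None:
        eps = 1e-6
    if mean is None:
        mean = sum(values) / len(values)
    if std is None:
        var = sum((v - mean) ** 2 for v in values) / len(values)
        std = math.sqrt(var)
    return [(v - mean) / max(std, eps) for v in values]


print(standardize([5.0, 5.0, 5.0]))  # constant input: [0.0, 0.0, 0.0], no division by zero
```

Without the ``eps`` floor, a constant input (``std == 0``) would raise ``ZeroDivisionError``; the real ``_standardize`` additionally broadcasts ``mean``/``std`` tensors and excludes dimensions via ``exclude_dims``.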
9 changes: 7 additions & 2 deletions torchrl/collectors/collectors.py
@@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
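The dispatch rule in the bullets above can be sketched with ``inspect`` (an illustrative approximation, not TorchRL's actual implementation; ``TensorDictBase`` is stubbed locally and ``needs_wrapping`` is a hypothetical name):

```python
import inspect


class TensorDictBase:  # local stand-in for tensordict.TensorDictBase in this sketch
    pass


def needs_wrapping(policy) -> bool:
    """Return True if the policy's forward does not already accept a tensordict."""
    sig = inspect.signature(policy.forward)
    params = [p for p in sig.parameters.values() if p.name != "self"]
    if len(params) != 1:
        return True  # multiple arguments: wrap with explicit in_keys/out_keys
    (param,) = params
    # Conventionally named single argument: no wrapping needed
    if param.name in ("tensordict", "td"):
        return False
    # Single argument annotated as a TensorDictBase subclass: no wrapping needed
    ann = param.annotation
    return not (isinstance(ann, type) and issubclass(ann, TensorDictBase))


class TDPolicy:
    def forward(self, td):
        return td


class PlainPolicy:
    def forward(self, obs, action):
        return obs


print(needs_wrapping(TDPolicy()), needs_wrapping(PlainPolicy()))  # False True
```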
Keyword Args:
@@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase):
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such:
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
@@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/generic.py
@@ -262,10 +262,12 @@ class DistributedDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -341,7 +343,7 @@ class DistributedDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 2 additions & 0 deletions torchrl/collectors/distributed/ray.py
@@ -135,10 +135,12 @@ class RayCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/rpc.py
@@ -110,10 +110,12 @@ class RPCDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -190,7 +192,7 @@ class RPCDataCollector(DataCollectorBase):
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/sync.py
@@ -143,10 +143,12 @@ class DistributedSyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -222,7 +224,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 1 addition & 1 deletion torchrl/data/datasets/common.py
@@ -72,7 +72,7 @@ def preprocess(
Args and Keyword Args are forwarded to :meth:`~tensordict.TensorDictBase.map`.
The dataset can subsequently be deleted using :meth:`~.delete`.
The dataset can subsequently be deleted using :meth:`delete`.
Keyword Args:
dest (path or equivalent): a path to the location of the new dataset.
4 changes: 2 additions & 2 deletions torchrl/data/datasets/openx.py
@@ -66,7 +66,7 @@ class for more information on how to interact with non-tensor data
sampling strategy.
If the ``batch_size`` is ``None`` (default), iterating over the
dataset will deliver trajectories one at a time *whereas* calling
:meth:`~.sample` will *still* require a batch-size to be provided.
:meth:`sample` will *still* require a batch-size to be provided.
Keyword Args:
shuffle (bool, optional): if ``True``, trajectories are delivered in a
@@ -115,7 +115,7 @@ class for more information on how to interact with non-tensor data
replacement (bool, optional): if ``False``, sampling will be done
without replacement. Defaults to ``True`` for downloaded datasets,
``False`` for streamed datasets.
pad (bool, float or None): if ``True``, trajectories of insufficient length
pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
given the `slice_len` or `num_slices` arguments will be padded with
0s. If another value is provided, it will be used for padding. If
``False`` or ``None`` (default) any encounter with a trajectory of
2 changes: 1 addition & 1 deletion torchrl/data/map/tdstorage.py
@@ -193,7 +193,7 @@ def from_tensordict_pair(
in the storage. Defaults to ``None`` (all keys are registered).
max_size (int, optional): the maximum number of elements in the storage. Ignored if the
``storage_constructor`` is passed. Defaults to ``1000``.
storage_constructor (type, optional): a type of tensor storage.
storage_constructor (Type, optional): a type of tensor storage.
Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
Other options include :class:`~tensordict.nn.storage.FixedStorage`.
hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.