diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index df7257762ec..eb4f6a69419 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -118,75 +118,49 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.
Weight Synchronization in Distributed Environments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------------------------
+
In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
-Local and Remote Weight Updaters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sending and receiving model weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
-and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
+The weight synchronization process is facilitated by one dedicated extension point:
+:class:`~torchrl.collectors.WeightUpdaterBase`. This base class provides a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
-- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
- the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
- different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
- It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
- situations where the server decides when to update the worker policies).
-- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
- remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
- the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
- devices or processes.
+:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to
+the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary.
+Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the
+weight synchronization with the policy.
+Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy
+state-dict (assuming the policy is a :class:`~torch.nn.Module` instance).
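+
+For instance, the default updater of a :class:`~torchrl.collectors.SyncDataCollector` copies the trained
+weights into the policy whenever :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_` is
+called. The following is a minimal sketch (the environment and policy are illustrative placeholders)::
+
+    >>> import torch.nn as nn
+    >>> from tensordict.nn import TensorDictModule
+    >>> from torchrl.collectors import SyncDataCollector
+    >>> from torchrl.envs.libs.gym import GymEnv
+    >>>
+    >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+    >>> collector = SyncDataCollector(
+    ...     create_env_fn=lambda: GymEnv("Pendulum-v1"),
+    ...     policy=policy,
+    ...     frames_per_batch=50,
+    ...     total_frames=200,
+    ... )
+    >>> for data in collector:
+    ...     ...  # update the policy weights on the training side
+    ...     collector.update_policy_weights_()  # the VanillaWeightUpdater refreshes the policy state-dict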
-Extending the Updater Classes
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Extending the Updater Class
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
-To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
+To accommodate diverse use cases, the API allows users to extend the updater class with custom implementations.
+The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation
+untouched.
This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
-setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
+setups.
+By implementing the abstract methods of this base class, users can define how weights are retrieved,
transformed, and applied, ensuring seamless integration with their existing infrastructure.
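+
+As an illustration, the sketch below broadcasts weights to a fixed number of workers. Only the extension
+points of :class:`~torchrl.collectors.WeightUpdaterBase` are shown; the transport used to ship the weights
+to each worker is a placeholder to be filled in::
+
+    >>> from torchrl.collectors import WeightUpdaterBase
+    >>>
+    >>> class BroadcastWeightUpdater(WeightUpdaterBase):
+    ...     def __init__(self, policy_weights, num_workers):
+    ...         self.policy_weights = policy_weights
+    ...         self.num_workers = num_workers
+    ...     def _get_server_weights(self):
+    ...         # called only when no weights are passed to `push_weights`
+    ...         return self.policy_weights
+    ...     def _maybe_map_weights(self, server_weights):
+    ...         # optionally reformat the weights before sending, e.g. move them to CPU
+    ...         return server_weights.cpu()
+    ...     def all_worker_ids(self):
+    ...         return list(range(self.num_workers))
+    ...     def _sync_weights_with_worker(self, *, worker_id=None, server_weights=None):
+    ...         # ship `server_weights` to `worker_id` with the transport of your choice
+    ...         ...
+
+The resulting updater can then be passed to a collector through its ``weight_updater`` keyword argument.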
-Default Implementations
-~~~~~~~~~~~~~~~~~~~~~~~
-
-For common scenarios, the API provides default implementations of these updaters, such as
-:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
-:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
-These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
-distributed systems.
-
-Practical Considerations
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-When designing a system that leverages this API, consider the following:
-
-- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
- implementation accounts for potential delays and optimizes data transfer where possible.
-- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
- the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
- suboptimal policy performance.
-- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
- overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
-By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
-scenarios, ensuring that their policies remain up-to-date and performant.
-
.. currentmodule:: torchrl.collectors
.. autosummary::
:toctree: generated/
:template: rl_template.rst
- WeightUpdateReceiverBase
- WeightUpdateSenderBase
- VanillaLocalWeightUpdater
- MultiProcessedRemoteWeightUpdate
- RayWeightUpdateSender
- DistributedWeightUpdateSender
- RPCWeightUpdateSender
+ WeightUpdaterBase
+ VanillaWeightUpdater
+    MultiProcessedWeightUpdate
+ RayWeightUpdater
+ DistributedWeightUpdater
+ RPCWeightUpdater
Collectors and replay buffers interoperability
----------------------------------------------
diff --git a/examples/collectors/mp_collector_mps.py b/examples/collectors/mp_collector_mps.py
index 0467e2bfd84..696adbe1540 100644
--- a/examples/collectors/mp_collector_mps.py
+++ b/examples/collectors/mp_collector_mps.py
@@ -45,12 +45,12 @@ class is necessary because MPS tensors cannot be sent over a pipe due to seriali
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
-from torchrl.collectors import MultiSyncDataCollector, WeightUpdateSenderBase
+from torchrl.collectors import MultiSyncDataCollector, WeightUpdaterBase
from torchrl.envs.libs.gym import GymEnv
-class MPSWeightUpdaterBase(WeightUpdateSenderBase):
+class MPSWeightUpdaterBase(WeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
@@ -101,7 +101,7 @@ def policy_factory(device=device):
reset_at_each_iter=False,
device=device,
storing_device="cpu",
- weight_update_sender=MPSWeightUpdaterBase(policy_weights, 2),
+ weight_updater=MPSWeightUpdaterBase(policy_weights, 2),
# use_buffers=False,
# cat_results="stack",
)
diff --git a/test/test_collector.py b/test/test_collector.py
index 5d2ea9d1008..9b6856bdbf9 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -39,11 +39,7 @@
prod,
seed_generator,
)
-from torchrl.collectors import (
- aSyncDataCollector,
- SyncDataCollector,
- WeightUpdateSenderBase,
-)
+from torchrl.collectors import aSyncDataCollector, SyncDataCollector, WeightUpdaterBase
from torchrl.collectors.collectors import (
_Interruptor,
MultiaSyncDataCollector,
@@ -3489,7 +3485,7 @@ def __deepcopy_error__(*args, **kwargs):
class TestPolicyFactory:
- class MPSWeightUpdaterBase(WeightUpdateSenderBase):
+ class MPSWeightUpdaterBase(WeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
@@ -3533,7 +3529,7 @@ def test_weight_update(self):
reset_at_each_iter=False,
device=device,
storing_device="cpu",
- weight_update_sender=self.MPSWeightUpdaterBase(policy_weights, 2),
+ weight_updater=self.MPSWeightUpdaterBase(policy_weights, 2),
)
collector.update_policy_weights_()
diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py
index 8e6c0d48fc5..a484c58a602 100644
--- a/torchrl/collectors/__init__.py
+++ b/torchrl/collectors/__init__.py
@@ -16,14 +16,12 @@
MultiProcessedWeightUpdate,
RayWeightUpdater,
VanillaWeightUpdater,
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
+ WeightUpdaterBase,
)
__all__ = [
"RandomPolicy",
- "WeightUpdateReceiverBase",
- "WeightUpdateSenderBase",
+ "WeightUpdaterBase",
"VanillaWeightUpdater",
"RayWeightUpdater",
"MultiProcessedWeightUpdate",
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 5485428a258..4d67f87c49a 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -53,8 +53,7 @@
from torchrl.collectors.weight_update import (
MultiProcessedWeightUpdate,
VanillaWeightUpdater,
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
+ WeightUpdaterBase,
)
from torchrl.data import ReplayBuffer
from torchrl.data.tensor_specs import TensorSpec
@@ -155,41 +154,22 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
- _weight_update_receiver: WeightUpdateReceiverBase | None = None
- _weight_update_sender: WeightUpdateSenderBase | None = None
+ _weight_updater: WeightUpdaterBase | None = None
@property
- def weight_update_receiver(self) -> WeightUpdateReceiverBase:
- return self._weight_update_receiver
+ def weight_updater(self) -> WeightUpdaterBase:
+ return self._weight_updater
- @weight_update_receiver.setter
- def weight_update_receiver(
- self,
- value: WeightUpdateReceiverBase | Callable[[], WeightUpdateReceiverBase] | None,
- ):
+ @weight_updater.setter
+    def weight_updater(self, value: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None):
if value is not None:
- if not isinstance(value, WeightUpdateReceiverBase) and callable(value):
+ if not isinstance(value, WeightUpdaterBase) and callable(value):
# then it's a constructor
value = value()
value.register_collector(self)
if value.collector is not self:
raise RuntimeError("Failed to register collector.")
- self._weight_update_receiver = value
-
- @property
- def weight_update_sender(self) -> WeightUpdateSenderBase:
- return self._weight_update_sender
-
- @weight_update_sender.setter
- def weight_update_sender(self, value: WeightUpdateSenderBase | None):
- if value is not None:
- if not isinstance(value, WeightUpdateSenderBase) and callable(value):
- # then it's a constructor
- value = value()
- value.register_collector(self)
- if value.collector is not self:
- raise RuntimeError("Failed to register collector.")
- self._weight_update_sender = value
+ self._weight_updater = value
def _get_policy_and_device(
self,
@@ -299,30 +279,20 @@ def update_policy_weights_(
for the update. If not provided, the method will attempt to fetch the weights using the configured
weight updater.
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
- workers that need to be updated. This is relevant when using a remote weights updater, which must
- be specified during the data collector's initialization. If `worker_ids` is provided without a
- configured remote weights updater, a TypeError will be raised.
+ workers that need to be updated. This is relevant when the collector has more than one worker associated
+ with it.
Raises:
- TypeError: If `worker_ids` is provided but no `weight_update_sender` is configured.
-
- .. note::
+ TypeError: If `worker_ids` is provided but no `weight_updater` is configured.
- - The method first attempts to update weights locally using `weight_update_receiver`, if available.
- - If a `weight_update_sender` is configured, it will be used to update the specified remote workers.
- - Users can extend the `WeightUpdateReceiverBase` and `WeightUpdateSenderBase` classes to customize
- the weight update logic for specific use cases. This method should not be overwritten.
+        .. note:: Users should extend the `WeightUpdaterBase` class to customize
+            the weight update logic for specific use cases. This method should not be overridden.
-        .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
-            :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.
+        .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase`.
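+
+        Example:
+            A sketch of typical usage; `collector` stands for any data collector whose
+            policy weights have been refreshed on the training side:
+
+                >>> collector.update_policy_weights_()  # weights are fetched through the configured updater
+                >>> collector.update_policy_weights_(policy_weights)  # push explicit weights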
"""
- if self.weight_update_receiver is not None:
- self.weight_update_receiver(policy_weights, **kwargs)
- if self.weight_update_sender is not None:
- self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
- elif worker_ids is not None:
- raise TypeError("worker_ids was passed but weight_update_sender was None.")
+ self.weight_updater(policy_weights, worker_ids=worker_ids, **kwargs)
def __iter__(self) -> Iterator[TensorDictBase]:
try:
@@ -537,12 +507,7 @@ class SyncDataCollector(DataCollectorBase):
or `ManiSkills `_) cuda synchronization may cause unexpected
crashes.
Defaults to ``False``.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the local inference worker.
- If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
- which directly fetches and applies the weights from the server.
- Consider using a constructor if the updater needs to be serialized.
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
-            or its subclass, responsible for updating the policy weights on remote inference workers.
-            This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
+            or its subclass, responsible for updating the policy weights.
+            If not provided, a :class:`~torchrl.collectors.VanillaWeightUpdater` will be used by default,
+            which directly applies the new weights to the policy.
Consider using a constructor if the updater needs to be serialized.
@@ -637,11 +602,8 @@ def __init__(
compile_policy: bool | dict[str, Any] | None = None,
cudagraph_policy: bool | dict[str, Any] | None = None,
no_cuda_sync: bool = False,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
- | None = None,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateSenderBase]
+ weight_updater: WeightUpdaterBase
+ | Callable[[], WeightUpdaterBase]
| None = None,
**kwargs,
):
@@ -893,13 +855,14 @@ def __init__(
self._frames = 0
self._iter = -1
- if weight_update_receiver is None:
- weight_update_receiver = VanillaWeightUpdater(
+ if weight_updater is None:
+ weight_updater = VanillaWeightUpdater(
weight_getter=self.get_weights_fn, policy_weights=self.policy_weights
)
+        elif not isinstance(weight_updater, WeightUpdaterBase) and not callable(weight_updater):
+            raise TypeError(
+                "weight_updater must be a WeightUpdaterBase instance or a constructor returning one."
+            )
- self.weight_update_receiver = weight_update_receiver
- self.weight_update_sender = weight_update_sender
+ self.weight_updater = weight_updater
@property
def _traj_pool(self):
@@ -1698,10 +1661,6 @@ class _MultiDataCollector(DataCollectorBase):
:class:`~torchrl.collectors.MultiaSyncDataCollector`
or a derived class of these.
Defaults to :class:`~torchrl.collectors.SyncDataCollector`.
-
- .. note:: This keyword argument is particularly handy when local attributes need to be
- set, such as `weight_update_receiver`.
-
max_frames_per_traj (int, optional): Maximum steps per trajectory.
Note that a trajectory can span across multiple batches (unless
``reset_at_each_iter`` is set to ``True``, see below).
@@ -1735,7 +1694,7 @@ class _MultiDataCollector(DataCollectorBase):
reset_when_done (bool, optional): if ``True`` (default), an environment
that return a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
- update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -1788,33 +1747,10 @@ class _MultiDataCollector(DataCollectorBase):
or `ManiSkills `_) cuda synchronization may cause unexpected
crashes.
Defaults to ``False``.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the server worker.
- If not provided, left unused.
- Consider using a constructor if the updater needs to be serialized.
-
- .. note:: This instance or constructor is not passed to the workers. To specify the workers `weight_update_receiver`
- instance, you can pass a `collector_class` argument containing the constructor:
-
- >>> from functools import partial
- >>> # the weight receiver - called when `worker_collector.update_policy_weight_() is called
- >>> worker_weight_updater_receiver = ...
- >>> # The weight sender - called when `server_collector.update_policy_weight_()` is called
- >>> server_weight_updater_sender = ...
- >>> collector = MultiaSyncDataCollector(
- ... create_env_fn=[func, func],
- ... policy=policy,
- ... frames_per_batch=100,
- ... total_frames=1000,
- ... collector_class=partial(SyncDataCollector, weight_update_receiver=worker_weight_updater_receiver),
- ... weight_update_sender=server_weight_updater_sender,
- ... )
-
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
or its subclass, responsible for updating the policy weights on remote inference workers.
- If not provided, a :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate` will be used by default,
+            If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdate` will be used by default,
which handles weight synchronization across multiple processes.
- See `weight_update_receiver` for details on the server / worker weight update API.
Consider using a constructor if the updater needs to be serialized.
"""
@@ -1857,11 +1793,8 @@ def __init__(
compile_policy: bool | dict[str, Any] | None = None,
cudagraph_policy: bool | dict[str, Any] | None = None,
no_cuda_sync: bool = False,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateReceiverBase]
- | None = None,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
+ weight_updater: WeightUpdaterBase
+ | Callable[[], WeightUpdaterBase]
| None = None,
):
self.closed = True
@@ -1956,21 +1889,20 @@ def __init__(
)
self._policy_weights_dict[policy_device] = weights
self._get_weights_fn = get_weights_fn
- if weight_update_sender is None:
- weight_update_sender = MultiProcessedWeightUpdate(
+ if weight_updater is None:
+ weight_updater = MultiProcessedWeightUpdate(
get_server_weights=self._get_weights_fn,
policy_weights=self._policy_weights_dict,
)
- elif weight_update_sender is None:
+ elif weight_updater is None:
warnings.warn(
- "weight_update_sender is None, but policy_factory is provided. This means that the server will "
+ "weight_updater is None, but policy_factory is provided. This means that the server will "
"not know how to send the weights to the workers. If the workers can handle their weight synchronization "
"on their own (via some specialized worker type / constructor) this may well work, but make sure "
"your weight synchronization strategy is properly set."
)
- self.weight_update_sender = weight_update_sender
- self.weight_update_receiver = weight_update_receiver
+ self.weight_updater = weight_updater
self.policy = policy
self.policy_factory = policy_factory
@@ -3137,7 +3069,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
reset_when_done (bool, optional): if ``True`` (default), an environment
that return a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
- update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 9aff1e018d0..a04e56ee076 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -31,10 +31,7 @@
TCP_PORT,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
-from torchrl.collectors.weight_update import (
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
-)
+from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -181,7 +178,7 @@ def _run_collector(
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()
else:
- if collector_kwargs.get("weight_update_sender") is None and (
+ if collector_kwargs.get("weight_updater") is None and (
policy_factory is None
or (isinstance(policy_factory, Sequence) and not any(policy_factory))
):
@@ -419,14 +416,9 @@ class DistributedDataCollector(DataCollectorBase):
to learn more.
Defaults to ``"submitit"``.
tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the local inference worker.
- This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it
- focuses on distributed environments.
- Consider using a constructor if the updater needs to be serialized.
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
or its subclass, responsible for updating the policy weights on distributed inference workers.
- If not provided, a :class:`~torchrl.collectors.distributed.DistributedWeightUpdateSender` will be used by
+ If not provided, a :class:`~torchrl.collectors.distributed.DistributedWeightUpdater` will be used by
default, which handles weight synchronization across distributed workers.
Consider using a constructor if the updater needs to be serialized.
@@ -464,11 +456,8 @@ def __init__(
max_weight_update_interval: int = -1,
launcher: str = "submitit",
tcp_port: int | None = None,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateSenderBase]
- | None = None,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
+ weight_updater: WeightUpdaterBase
+ | Callable[[], WeightUpdaterBase]
| None = None,
):
@@ -489,10 +478,9 @@ def __init__(
policy_weights = policy_weights.data.lock_()
elif any(policy_factory):
policy_weights = None
- if weight_update_sender is None:
+ if weight_updater is None:
raise RuntimeError(
- "weight_update_sender must be passed along with "
- "a policy_factory."
+ "weight_updater must be passed along with " "a policy_factory."
)
else:
if not any(policy_factory):
@@ -576,15 +564,14 @@ def __init__(
self._init_workers()
self._make_container()
- if weight_update_sender is None:
- weight_update_sender = DistributedWeightUpdater(
+ if weight_updater is None:
+ weight_updater = DistributedWeightUpdater(
store=self._store,
policy_weights=self.policy_weights,
num_workers=self.num_workers,
sync=self._sync,
)
- self.weight_update_sender = weight_update_sender
- self.weight_update_receiver = weight_update_receiver
+ self.weight_updater = weight_updater
@property
def device(self) -> list[torch.device]:
@@ -986,10 +973,10 @@ def shutdown(self):
torchrl_logger.info("collector shut down")
-class DistributedWeightUpdater(WeightUpdateSenderBase):
+class DistributedWeightUpdater(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across distributed workers.
- The `DistributedWeightUpdateSender` class provides a mechanism for updating the weights
+ The `DistributedWeightUpdater` class provides a mechanism for updating the weights
of a policy across distributed inference workers. It is designed to work with the
:class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights.
This class is typically used in distributed data collection scenarios where multiple workers
@@ -1014,12 +1001,12 @@ class DistributedWeightUpdater(WeightUpdateSenderBase):
.. note::
This class assumes that the server weights can be directly applied to the distributed workers
without any additional processing. If your use case requires more complex weight mapping or
- synchronization logic, consider extending `WeightUpdateSenderBase` with a custom implementation.
+ synchronization logic, consider extending `WeightUpdaterBase` with a custom implementation.
Raises:
RuntimeError: If the worker rank is less than 1 or if the status returned by the store is not "updated".
- .. seealso:: :class:`~torchrl.collectors.WeightUpdateSenderBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
:class:`~torchrl.collectors.distributed.DistributedDataCollector`.
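+
+    Example:
+        A sketch of typical usage; `collector` stands for a
+        :class:`~torchrl.collectors.distributed.DistributedDataCollector` that relies on this
+        updater (the default):
+
+            >>> collector.update_policy_weights_()  # push the latest weights to every worker
+            >>> collector.update_policy_weights_(worker_ids=[1, 2])  # worker ranks start at 1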
"""
@@ -1053,7 +1040,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError
- def update_weights(
+ def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 940fd6cd352..78a9317c8b8 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -22,11 +22,7 @@
SyncDataCollector,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
-from torchrl.collectors.weight_update import (
- RayWeightUpdater,
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
-)
+from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase
from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -277,14 +273,9 @@ class RayCollector(DataCollectorBase):
-    .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a
+    .. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
:class:`~torchrl.data.RayReplayBuffer` instance should be used here.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the local inference worker.
- This is typically not used in :class:`~torchrl.collectors.RayCollector` as it focuses on distributed
- environments.
- Consider using a constructor if the updater needs to be serialized.
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray.
- If not provided, a :class:`~torchrl.collectors.RayWeightUpdateSender` will be used by default, leveraging
+ If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging
Ray's distributed capabilities.
Consider using a constructor if the updater needs to be serialized.
@@ -347,11 +338,8 @@ def __init__(
update_after_each_batch: bool = False,
max_weight_update_interval: int = -1,
replay_buffer: ReplayBuffer | None = None,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateSenderBase]
- | None = None,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
+ weight_updater: WeightUpdaterBase
+ | Callable[[], WeightUpdaterBase]
| None = None,
):
self.frames_per_batch = frames_per_batch
@@ -469,7 +457,7 @@ def check_list_length_consistency(*lists):
policy_weights = policy_weights.data.lock_()
else:
policy_weights = TensorDict(lock=True)
- if weight_update_sender is None:
+ if weight_updater is None:
warnings.warn(_NON_NN_POLICY_WEIGHTS)
self.policy_weights = policy_weights
self.collector_class = collector_class
@@ -537,14 +525,13 @@ def check_list_length_consistency(*lists):
collector_kwargs,
remote_configs,
)
- if weight_update_sender is None:
- weight_update_sender = RayWeightUpdater(
+ if weight_updater is None:
+ weight_updater = RayWeightUpdater(
policy_weights=policy_weights,
remote_collectors=self.remote_collectors,
max_interval=self.max_weight_update_interval,
)
- self.weight_update_sender = weight_update_sender
- self.weight_update_receiver = weight_update_receiver
+ self.weight_updater = weight_updater
# Print info of all remote workers
pending_samples = [
diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py
index 28c77daa56c..e56def1b1f8 100644
--- a/torchrl/collectors/distributed/rpc.py
+++ b/torchrl/collectors/distributed/rpc.py
@@ -36,10 +36,7 @@
TCP_PORT,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
-from torchrl.collectors.weight_update import (
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
-)
+from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -265,14 +262,9 @@ class RPCDataCollector(DataCollectorBase):
device used to pass data to main.
tensorpipe_options (dict, optional): a dictionary of keyword argument
to pass to :class:`torch.distributed.rpc.TensorPipeRpcBackendOption`.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the local inference worker. This is
- typically not used in :class:`~torchrl.collectors.distrbibuted.RPCDataCollector` as it focuses on
- distributed environments.
- Consider using a constructor if the updater needs to be serialized.
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
or its subclass, responsible for updating the policy weights on remote inference workers using RPC.
- If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdateSender` will be used by default, which
+ If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdater` will be used by default, which
handles weight synchronization via RPC.
Consider using a constructor if the updater needs to be serialized.
@@ -311,11 +303,8 @@ def __init__(
tcp_port: str | None = None,
visible_devices: list[torch.device] | None = None,
tensorpipe_options: dict[str, Any] | None = None,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateSenderBase]
- | None = None,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
+ weight_updater: WeightUpdaterBase
+ | Callable[[], WeightUpdaterBase]
| None = None,
):
if collector_class == "async":
@@ -331,7 +320,7 @@ def __init__(
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()
else:
- if weight_update_sender is None and (
+ if weight_updater is None and (
policy_factory is None
or (isinstance(policy_factory, Sequence) and not any(policy_factory))
):
@@ -422,16 +411,15 @@ def __init__(
tensorpipe_options
)
self._init()
- if weight_update_sender is None:
- weight_update_sender = RPCWeightUpdaterBase(
+ if weight_updater is None:
+ weight_updater = RPCWeightUpdaterBase(
collector_infos=self.collector_infos,
collector_class=self.collector_class,
collector_rrefs=self.collector_rrefs,
policy_weights=self.policy_weights,
num_workers=self.num_workers,
)
- self.weight_update_receiver = weight_update_receiver
- self.weight_update_sender = weight_update_sender
+ self.weight_updater = weight_updater
@property
def device(self) -> list[torch.device]:
@@ -822,10 +810,10 @@ def shutdown(self):
self._shutdown = True
-class RPCWeightUpdaterBase(WeightUpdateSenderBase):
+class RPCWeightUpdaterBase(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across remote workers using RPC.
- The `RPCWeightUpdateSender` class provides a mechanism for updating the weights of a policy
+    The `RPCWeightUpdaterBase` class provides a mechanism for updating the weights of a policy
across remote inference workers using RPC. It is designed to work with the :class:`~torchrl.collectors.distributed.RPCDataCollector`
to ensure that each worker receives the latest policy weights.
This class is typically used in distributed data collection scenarios where remote workers
@@ -849,9 +837,9 @@ class RPCWeightUpdaterBase(WeightUpdateSenderBase):
.. note::
This class assumes that the server weights can be directly applied to the remote workers
without any additional processing. If your use case requires more complex weight mapping or
- synchronization logic, consider extending `WeightUpdateSenderBase` with a custom implementation.
+ synchronization logic, consider extending `WeightUpdaterBase` with a custom implementation.
- .. seealso:: :class:`~torchrl.collectors.WeightUpdateSenderBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
:class:`~torchrl.collectors.distributed.RPCDataCollector`.
"""
@@ -887,7 +875,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError
- def update_weights(
+ def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py
index d93b6b2b3dc..08133242974 100644
--- a/torchrl/collectors/distributed/sync.py
+++ b/torchrl/collectors/distributed/sync.py
@@ -80,7 +80,7 @@ def _distributed_init_collection_node(
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()
else:
- if collector_kwargs.get("weight_update_sender") is None and (
+ if collector_kwargs.get("weight_updater") is None and (
policy_factory is None
or (isinstance(policy_factory, Sequence) and not any(policy_factory))
):
@@ -327,7 +327,7 @@ def __init__(
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()
else:
- if collector_kwargs.get("weight_update_sender") is None and (
+ if collector_kwargs.get("weight_updater") is None and (
policy_factory is None
or (isinstance(policy_factory, Sequence) and not any(policy_factory))
):
diff --git a/torchrl/collectors/llm.py b/torchrl/collectors/llm.py
deleted file mode 100644
index accf3ea80a2..00000000000
--- a/torchrl/collectors/llm.py
+++ /dev/null
@@ -1,366 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-from __future__ import annotations
-
-from collections import deque
-from typing import Callable
-
-import torch
-
-from tensordict import lazy_stack, TensorDictBase
-
-from torchrl.collectors import (
- SyncDataCollector,
- WeightUpdateReceiverBase,
- WeightUpdateSenderBase,
-)
-from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
-from torchrl.envs import AsyncEnvPool
-from torchrl.envs.common import EnvBase
-
-
-class LLMCollector(SyncDataCollector):
- """A simplified version of SyncDataCollector for LLM inference.
-
- Args:
- env (EnvBase or EnvBase constructor): the environment to be used for data collection.
-
- Keyword Args:
- policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
- policy_factory (Callable[[], Callable], optional): a callable that returns
- a policy instance. This is exclusive with the `policy` argument.
-
- .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
-
- steps_per_batch (int): A keyword-only argument representing the total
- number of elements in a batch; -1 is never ending (until shutdown).
- total_steps (int): A keyword-only argument representing the total
- number of steps returned by the collector
- during its lifespan.
- yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
- (`yield_completed_trajectories=False`, default) or single, completed trajectories
- (`yield_completed_trajectories=True`).
- Defaults to `False` unless `yield_only_last_steps=True`, where it cannot be `False`.
-
- .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
- that never leads any data.
-
- yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
- last (done) steps.
- If `True`, a single trajectory is yielded (or written in the buffer) at a time.
-
- .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
- that never leads any data.
-
- postproc (Callable, optional): A post-processing transform, such as
- a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
- instance.
- Defaults to ``None``.
- async_envs (bool, optional): if ``True``, the environment will be run asynchronously. Defaults to `True` if the
- environment is a :class:`~torchrl.envs.AsyncEnvPool` instance.
- replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
- but populate the buffer instead. Defaults to ``None``.
- reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
- flatten_data (bool, optional): if ``True``, the collector will flatten the collected data
- before returning it. In practice, this means that if an environment of batch-size `(B,)` is used
- and run for `T` steps, `flatten_data=True` will present data of shape `(B*T,)`, whereas
- `flatten_data=False` will not present data of shape `(B, T)`.
- Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
- weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- or its subclass, responsible for updating the policy weights on the local inference worker.
- If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
- which directly fetches and applies the weights from the server.
- Consider using a constructor if the updater needs to be serialized.
- weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
- or its subclass, responsible for updating the policy weights on remote inference workers.
- This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
- Consider using a constructor if the updater needs to be serialized.
-
- Examples:
- >>> import vllm
- >>> from torchrl.modules import vLLMWrapper
- >>> from pytorch.rl.test.mocking_classes import DummyStrDataLoader
- >>> from torchrl.envs import LLMEnv
- >>> llm_model = vllm.LLM("gpt2")
- >>> tokenizer = llm_model.get_tokenizer()
- >>> tokenizer.pad_token = tokenizer.eos_token
- >>> policy = vLLMWrapper(llm_model)
- >>> dataloader = DummyStrDataLoader(1)
- >>> env = LLMEnv.from_dataloader(
- ... dataloader=dataloader,
- ... tokenizer=tokenizer,
- ... from_text=True,
- ... batch_size=1,
- ... group_repeats=True,
- ... )
- >>> collector = LLMCollector(
- ... env=env,
- ... policy_factory=lambda: policy,
- ... steps_per_batch=env.batch_size[0],
- ... total_steps=3,
- ... )
- >>> for i, data in enumerate(collector):
- ... if i == 2:
- ... print(data)
- ... break
- LazyStackedTensorDict(
- fields={
- attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
- collector: LazyStackedTensorDict(
- fields={
- traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
- exclusive_fields={
- },
- batch_size=torch.Size([1, 1]),
- device=None,
- is_shared=False,
- stack_dim=1),
- done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
- terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
- text: NonTensorStack(
- [['plsgqejeyd']],
- batch_size=torch.Size([1, 1]),
- device=None),
- text_response: NonTensorStack(
- [['ec.n.n.n.tjbjz3perwhz']],
- batch_size=torch.Size([1, 1]),
- device=None),
- tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
- tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)},
- exclusive_fields={
- },
- batch_size=torch.Size([1, 1]),
- device=None,
- is_shared=False,
- stack_dim=1)
- >>> del collector
-
- """
-
- def __init__(
- self,
- env: EnvBase | Callable[[], EnvBase],
- *,
- policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
- policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
- | None = None,
- steps_per_batch: int,
- yield_only_last_steps: bool | None = None,
- yield_completed_trajectories: bool | None = None,
- postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
- total_steps: int = -1,
- async_envs: bool | None = None,
- replay_buffer: ReplayBuffer | None = None,
- reset_at_each_iter: bool = False,
- flatten_data: bool | None = None,
- weight_update_receiver: WeightUpdateReceiverBase
- | Callable[[], WeightUpdateReceiverBase]
- | None = None,
- weight_update_sender: WeightUpdateSenderBase
- | Callable[[], WeightUpdateSenderBase]
- | None = None,
- ):
- super().__init__(
- create_env_fn=env,
- policy=policy,
- policy_factory=policy_factory,
- frames_per_batch=steps_per_batch,
- replay_buffer=replay_buffer,
- total_frames=total_steps,
- weight_update_receiver=weight_update_receiver,
- weight_update_sender=weight_update_sender,
- reset_at_each_iter=reset_at_each_iter,
- trust_policy=True,
- use_buffers=False,
- no_cuda_sync=True,
- extend_buffer=True,
- )
- if yield_only_last_steps is None:
- yield_only_last_steps = False
-
- if yield_completed_trajectories is None:
- yield_completed_trajectories = yield_only_last_steps
- elif yield_only_last_steps and not yield_completed_trajectories:
- raise TypeError(
- "yield_only_last_steps=True requires yield_completed_trajectories=True (or None)"
- )
-
- if yield_only_last_steps:
- if flatten_data is not None:
- raise TypeError(
- "`yield_only_last_steps` cannot be `True` when `flatten_data` is passed."
- )
- if self.reset_at_each_iter:
- raise TypeError(
- "`yield_only_last_steps` cannot be `True` when `reset_at_each_iter=True`."
- )
- if flatten_data is None:
- flatten_data = replay_buffer is not None
- self.flatten_data = flatten_data
- self.yield_completed_trajectories = yield_completed_trajectories
- self.yield_only_last_steps = yield_only_last_steps
- if self.yield_completed_trajectories:
- if len(self.env.batch_size) != 1:
- raise ValueError(
- "`yield_only_last_steps` only works with envs that have a single batch dimension. Got "
- f"env.batch_size={self.env.batch_size}."
- )
- self._yield_queues = [deque() for _ in range(self.env.batch_size[0])]
- self._trajectory_queue = deque()
- self.async_envs = bool(async_envs) | isinstance(self.env, AsyncEnvPool)
- if self.async_envs and not isinstance(self.env, AsyncEnvPool):
- # This basically means that `async_envs` is automatically set and passing is it useless as of today,
- # except for the following error.
- raise RuntimeError(
- "async_envs requires the environment to be an AsyncEnvPool instance."
- )
-
- @property
- def steps_per_batch(self) -> int:
- """Alias to `frames_per_batch`."""
- return self.frames_per_batch
-
- @property
- def rollout(self) -> Callable[[], TensorDictBase]:
- if self.yield_completed_trajectories:
- if self.async_envs:
- return self._rollout_yield_trajs_async
- else:
- return self._rollout_yield_trajs
- else:
- return self._rollout_all
-
- def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout
- if self.reset_at_each_iter or self._shuttle is None:
- data = self.env.reset()
- else:
- data = self._shuttle
-
- trajectory = []
- collected_steps = 0
- while collected_steps < self.steps_per_batch:
- policy_input = data
- env_input = self.policy(policy_input)
- env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
-
- # carry over collector data without messing up devices
- collector_data = env_output.get("collector").copy()
- env_next_output.set("collector", collector_data)
- self._shuttle = env_next_output
- self._update_traj_ids(env_output)
- data = env_output
- trajectory.append(data)
- collected_steps += data.numel()
- trajectory = lazy_stack(trajectory, -1)
- if self.flatten_data:
- return trajectory.view(-1)
- return trajectory
-
- def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rollout
- if self._shuttle is None:
- raise RuntimeError("Data shuttle not found")
- # next_output = self.env.reset()
- else:
- next_output = self._shuttle
-
- collected_steps = 0
- dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
- while True:
- if self._trajectory_queue:
- break
- env_input = self.policy(next_output)
- cur_output, next_output = self.env.step_and_maybe_reset(env_input)
- # for i in range(cur_output.numel()):
- # print(len(cur_output[i]["text"]) < len(cur_output[i]["next", "text"]))
-
- # carry over collector data without messing up devices
- self._update_traj_ids(cur_output)
-
- collector_data = cur_output.get("collector").copy()
- next_output.set("collector", collector_data)
-
- # if the loop is interrupted
- self._shuttle = next_output
- collected_steps += next_output.numel()
- for i, (_data, queue) in enumerate(
- zip(cur_output.unbind(0), self._yield_queues)
- ):
- queue.append(_data)
- dones[i] = _data["next", "done"].any()
- if dones.any():
- for idx in dones.nonzero()[0].tolist():
- if not self.yield_only_last_steps:
- self._trajectory_queue.append(
- lazy_stack(self._yield_queues[idx], -1)
- )
- else:
- # FIXME: We need to increment the step count here because iterator() won't
- # see the extra steps
- # We use lazy-stack because unsqueeze doesn't nest the strings in lists
- self._trajectory_queue.append(
- lazy_stack([self._yield_queues[idx][-1]])
- )
- self._yield_queues[idx].clear()
-
- result = self._trajectory_queue.popleft()
- return result
-
- started = False
-
- def _rollout_yield_trajs_async(
- self,
- ) -> TensorDictBase: # A simplified version of rollout
- if not self.started:
- next_output = self._shuttle
- env_input = self.policy(next_output)
- self.env.async_step_and_maybe_reset_send(env_input)
- self.started = True
-
- collected_steps = 0
- dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
- while True:
- if self._trajectory_queue:
- break
-
- cur_output, next_output = self.env.async_step_and_maybe_reset_recv()
-
- # Get the env ids
- env_ids = cur_output.get(self.env._env_idx_key).tolist()
-
- # carry over collector data without messing up devices
- self._update_traj_ids(cur_output)
-
- collector_data = cur_output.get("collector").copy()
- next_output.set("collector", collector_data)
-
- collected_steps += next_output.numel()
- dones.fill_(False)
- for i, _data in zip(env_ids, cur_output.unbind(0)):
- queue = self._yield_queues[i]
- queue.append(_data)
- dones[i] = _data["next", "done"].any()
- if dones.any():
- for idx in dones.nonzero()[0].tolist():
- if not self.yield_only_last_steps:
- self._trajectory_queue.append(
- lazy_stack(self._yield_queues[idx], -1)
- )
- else:
- # FIXME: We need to increment the step count here because iterator() won't
- # see the extra steps
- # We use lazy-stack because unsqueeze doesn't nest the strings in lists
- self._trajectory_queue.append(
- lazy_stack([self._yield_queues[idx][-1]])
- )
- self._yield_queues[idx].clear()
-
- # Launch the next batch:
- # FIXME: Add a condition RE number of frames here
- if True:
- env_input = self.policy(next_output)
- self.env.async_step_and_maybe_reset_send(env_input)
-
- result = self._trajectory_queue.popleft()
- return result
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 55d1586307f..36dc20f0858 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -13,7 +13,7 @@
_NON_NN_POLICY_WEIGHTS = (
"The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and "
- "update_policy_weights_ will be a no-op. Consider passing a local/weight_update_sender object "
+ "update_policy_weights_ will be a no-op. Consider passing a local/weight_updater object "
"to your collector to handle the weight updates."
)
diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py
index ea92cb37fb2..e0783b3c321 100644
--- a/torchrl/collectors/weight_update.py
+++ b/torchrl/collectors/weight_update.py
@@ -17,119 +17,42 @@
Policy = TypeVar("Policy", bound=TensorDictModuleBase)
-class WeightUpdateReceiverBase(metaclass=abc.ABCMeta):
- """A base class for updating local policy weights from a server.
-
- This class provides an interface for downloading and updating the weights of a policy
- on a local inference worker. The update process is decentralized, meaning the inference
- worker is responsible for fetching the weights from the server.
-
- To extend this class, implement the following abstract methods:
-
- - `_get_server_weights`: Define how to retrieve the weights from the server.
- - `_get_local_weights`: Define how to access the current local weights.
- - `_maybe_map_weights`: Optionally transform the server weights to match the local weights.
-
- Attributes:
- policy (Policy, optional): The policy whose weights are to be updated.
- get_weights_from_policy (Callable, optional): A function to extract weights from the policy.
- get_weights_from_server (Callable, optional): A function to fetch weights from the server.
- weight_map_fn (Callable, optional): A function to map server weights to local weights.
- cache_policy_weights (bool): Whether to cache the policy weights locally.
-
- Methods:
- update_weights: Updates the local weights with the server weights.
-
-
- .. seealso:: :class:`~torchrl.collectors.RemoteWeightsUpdaterBase` and
- :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
-
- """
-
- _collector_wr: Any = None
-
- def register_collector(self, collector: DataCollectorBase): # noqa
- """Register a collector in the updater.
-
- Once registered, the updater will not accept another collector.
-
- Args:
- collector (DataCollectorBase): The collector to register.
-
- """
- if self._collector_wr is not None:
- raise RuntimeError("Cannot register collector twice.")
- self._collector_wr = weakref.ref(collector)
-
- @property
- def collector(self) -> torchrl.collectors.DataCollectorBase: # noqa
- return self._collector_wr() if self._collector_wr is not None else None
-
- @abstractmethod
- def _get_server_weights(self) -> TensorDictBase:
- ...
-
- @abstractmethod
- def _get_local_weights(self) -> TensorDictBase:
- ...
-
- @abstractmethod
- def _maybe_map_weights(
- self, server_weights: TensorDictBase, local_weights: TensorDictBase
- ) -> TensorDictBase:
- ...
-
- def _update_local_weights(
- self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
- ) -> TensorDictBase:
- local_weights.update_(mapped_weights)
-
- def __call__(
- self,
- weights: TensorDictBase | None = None,
- ):
- return self.update_weights(weights=weights)
-
- def update_weights(self, weights: TensorDictBase | None = None) -> TensorDictBase:
- if weights is None:
- # get server weights (source)
- server_weights = self._get_server_weights()
- else:
- server_weights = weights
- # Get local weights
- local_weights = self._get_local_weights()
-
- # Optionally map the weights
- mapped_weights = self._maybe_map_weights(server_weights, local_weights)
+class WeightUpdaterBase(metaclass=abc.ABCMeta):
+ """A base class for updating remote policy weights on inference workers.
- # Update the weights
- self._update_local_weights(local_weights, mapped_weights)
+    The weight updater is the central piece of the weight update scheme:
+
+    - In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as
+ updating a state-dict, or more complex if an inference server is being used.
+ - In server collector nodes, it is responsible for sending the weights to the leaf collectors.
-class WeightUpdateSenderBase(metaclass=abc.ABCMeta):
- """A base class for updating remote policy weights on inference workers.
+    In a collector, the updater is called within :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
- This class provides an interface for uploading and synchronizing the weights of a policy
- across remote inference workers. The update process is centralized, meaning the server
- is responsible for distributing the weights to the inference nodes.
+ The main method of this class is the :meth:`~.push_weights` method, which updates the policy weights in the worker /
+ policy.
To extend this class, implement the following abstract methods:
+    - `_get_server_weights` (optional): Define how to retrieve the weights from the server. This method is only
+      called if the weights (or a handle to them) are not passed to the updater directly.
- `_sync_weights_with_worker`: Define how to synchronize weights with a specific worker.
- - `_get_server_weights`: Define how to retrieve the weights from the server.
+ This method must be implemented by child classes.
- `_maybe_map_weights`: Optionally transform the server weights before distribution.
+ By default, this method returns the weights unchanged.
- `all_worker_ids`: Provide a list of all worker identifiers.
+ Returns `None` by default (no worker id).
Attributes:
- policy (Policy, optional): The policy whose weights are to be updated.
+        collector: The collector (or any container) holding the weight updater. The collector is registered via
+            :meth:`~torchrl.collectors.WeightUpdaterBase.register_collector`.
Methods:
- update_weights: Updates the weights on specified or all remote workers.
- register_collector: Registers a collector. This should be called automatically by the collector
- upon registration of the updater.
+ push_weights: Updates the weights on specified or all remote workers.
+ The `__call__` method is a proxy to `push_weights`.
+        register_collector: Registers the collector (or any container) in the updater through a weakref.
+ This will be called automatically by the collector upon registration of the updater.
- .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
- :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
+ .. seealso:: :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
"""
@@ -150,41 +73,65 @@ def register_collector(self, collector: DataCollectorBase): # noqa
@property
def collector(self) -> torch.collector.DataCollectorBase: # noqa
+ """The collector or container of the receiver.
+
+ Returns `None` if the container is out-of-scope or not set.
+ """
return self._collector_wr() if self._collector_wr is not None else None
+ def _get_server_weights(self) -> Any:
+ """An optional method to gather weights from the server.
+
+        This method is called only if the weights (or a handle to them) are not passed directly to the update method.
+ """
+ raise NotImplementedError
+
@abstractmethod
def _sync_weights_with_worker(
- self, worker_id: int | torch.device, server_weights: TensorDictBase
- ) -> TensorDictBase:
- ...
+ self, *, worker_id: int | torch.device | None = None, server_weights: Any
+ ) -> Any:
+ """An abstract method that updates the weights on specified workers.
- @abstractmethod
- def _get_server_weights(self) -> TensorDictBase:
+ The worker id can be `None` if there are no workers associated with the updater.
+ """
...
- @abstractmethod
- def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
- ...
+ def _maybe_map_weights(self, server_weights: Any) -> Any:
+ """Optionally transforms the server weights to match the local weights."""
+ return server_weights
- @abstractmethod
- def all_worker_ids(self) -> list[int] | list[torch.device]:
- ...
+ def all_worker_ids(self) -> list[int] | list[torch.device] | None:
+ """Returns a list of all worker identifiers or `None` if there are no workers associated."""
+ return
def _skip_update(self, worker_id: int | torch.device) -> bool:
+ """A method to determine if a worker should be skipped."""
return False
def __call__(
self,
- weights: TensorDictBase | None = None,
+ weights: Any = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
):
- return self.update_weights(weights=weights, worker_ids=worker_ids)
+ """A proxy to :meth:`~.push_weights`."""
+ return self.push_weights(weights=weights, worker_ids=worker_ids)
- def update_weights(
+ def push_weights(
self,
- weights: TensorDictBase | None = None,
+ *,
+ weights: Any | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
):
+ """Updates the weights of the policy, or on specified / all remote workers.
+
+ Keyword Args:
+ weights (Any, optional): the source weights to push to the policy / workers. If omitted,
+ they are fetched through :meth:`~._get_server_weights`.
+ worker_ids (torch.device | int | list[int] | list[torch.device], optional): the workers
+ to update. Defaults to `None` (all workers).
+
+ Returns: nothing.
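+
+ Example:
+ A sketch (``updater`` is a concrete subclass instance and ``new_weights`` holds
+ freshly trained parameters):
+
+ >>> updater.push_weights(weights=new_weights)  # push the given weights
+ >>> updater.push_weights()  # fall back on _get_server_weights()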
+
+ """
if weights is None:
# Get the weights on server (local)
server_weights = self._get_server_weights()
@@ -198,47 +145,42 @@ def update_weights(
worker_ids = [worker_ids]
elif worker_ids is None:
worker_ids = self.all_worker_ids()
+ if worker_ids is None:
+ self._sync_weights_with_worker(server_weights=server_weights)
+ return
for worker in worker_ids:
if self._skip_update(worker):
continue
- self._sync_weights_with_worker(worker, server_weights)
+ self._sync_weights_with_worker(
+ worker_id=worker, server_weights=server_weights
+ )
# Specialized classes
-class VanillaWeightUpdater(WeightUpdateReceiverBase):
- """A simple implementation of `WeightUpdateReceiverBase` for updating local policy weights.
+class VanillaWeightUpdater(WeightUpdaterBase):
+ """A simple implementation of :class:`~torchrl.collectors.WeightUpdaterBase` for updating local policy weights.
- The `VanillaLocalWeightUpdater` class provides a basic mechanism for updating the weights
+ The `VanillaWeightUpdater` class provides a basic mechanism for updating the weights
of a local policy by directly fetching them from a specified source. It is typically used
in scenarios where the weight update logic is straightforward and does not require any
complex mapping or transformation.
- This class is used by default in the `SyncDataCollector` when no custom local weights updater
+ This class is used by default in the `SyncDataCollector` when no custom weight updater
is provided.
- Args:
- weight_getter (Callable[[], TensorDictBase]): A callable that returns the latest policy
- weights from the server or another source.
- policy_weights (TensorDictBase): The current weights of the local policy that need to be updated.
-
- Methods:
- _get_server_weights: Retrieves the latest weights from the specified source.
- _get_local_weights: Accesses the current local policy weights.
- _map_weights: Directly maps server weights to local weights without transformation.
- _maybe_map_weights: Optionally maps server weights to local weights (no-op in this implementation).
- _update_local_weights: Updates the local policy weights with the mapped weights.
-
- .. note::
- This class assumes that the server weights can be directly applied to the local policy
- without any additional processing. If your use case requires more complex weight mapping,
- consider extending `WeightUpdateReceiverBase` with a custom implementation.
-
- .. seealso:: :class:`~torchrl.collectors.WeightUpdateReceiverBase` and :class:`~torchrl.collectors.SyncDataCollector`.
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and :class:`~torchrl.collectors.SyncDataCollector`.
+
+ Keyword Args:
+ weight_getter (Callable[[], TensorDictBase], optional): a callable that returns the weights from the server.
+ If not provided, the weights must be passed to :meth:`~.push_weights` directly.
+ policy_weights (TensorDictBase): a TensorDictBase containing the policy weights to be updated
+ in-place.
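+
+ Example:
+ A minimal sketch, assuming a plain :class:`~torch.nn.Module` policy:
+
+ >>> import torch.nn as nn
+ >>> from tensordict import TensorDict
+ >>> policy = nn.Linear(3, 4)
+ >>> policy_weights = TensorDict.from_module(policy).data
+ >>> updater = VanillaWeightUpdater(policy_weights=policy_weights)
+ >>> updater.push_weights(weights=policy_weights.clone().zero_())  # zeroes the policy in-place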
"""
def __init__(
self,
- weight_getter: Callable[[], TensorDictBase],
+ *,
+ weight_getter: Callable[[], TensorDictBase] | None = None,
policy_weights: TensorDictBase,
):
self.weight_getter = weight_getter
@@ -253,53 +195,46 @@ def _get_local_weights(self) -> TensorDictBase:
def _map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights
- def _maybe_map_weights(
- self, server_weights: TensorDictBase, local_weights: TensorDictBase
- ) -> TensorDictBase:
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights
- def _update_local_weights(
- self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
+ def _sync_weights_with_worker(
+ self, *, worker_id: None = None, server_weights: TensorDictBase
) -> TensorDictBase:
- if local_weights is None or mapped_weights is None:
+ if server_weights is None:
return
- local_weights.update_(mapped_weights)
+ self.policy_weights.update_(server_weights)
-class MultiProcessedWeightUpdate(WeightUpdateSenderBase):
+class MultiProcessedWeightUpdate(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across multiple processes or devices.
- The `MultiProcessedRemoteWeightUpdate` class provides a mechanism for updating the weights
+ The `MultiProcessedWeightUpdate` class provides a mechanism for updating the weights
of a policy across multiple inference workers in a multiprocessed environment. It is designed
to handle the distribution of weights from a central server to various devices or processes
that are running the policy.
This class is typically used in multiprocessed data collectors where each process or device
requires an up-to-date copy of the policy weights.
- Args:
+ Keyword Args:
get_server_weights (Callable[[], TensorDictBase] | None): A callable that retrieves the
latest policy weights from the server or another centralized source.
policy_weights (Dict[torch.device, TensorDictBase]): A dictionary mapping each device or
process to its current policy weights, which will be updated.
- Methods:
- all_worker_ids: Returns a list of all worker identifiers (devices or processes).
- _sync_weights_with_worker: Synchronizes the server weights with a specific worker.
- _get_server_weights: Retrieves the latest weights from the server.
- _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation).
-
.. note::
This class assumes that the server weights can be directly applied to the workers without
any additional processing. If your use case requires more complex weight mapping or synchronization
- logic, consider extending `WeightUpdateSenderBase` with a custom implementation.
+ logic, consider extending `WeightUpdaterBase` with a custom implementation.
- .. seealso:: :class:`~torchrl.collectors.WeightUpdateSenderBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
:class:`~torchrl.collectors.DataCollectorBase`.
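+
+ Example:
+ A sketch of stand-alone use (multiprocessed collectors normally build this updater
+ themselves; the single-entry ``worker_weights`` map is illustrative):
+
+ >>> import torch
+ >>> import torch.nn as nn
+ >>> from tensordict import TensorDict
+ >>> policy = nn.Linear(3, 4)
+ >>> server_weights = TensorDict.from_module(policy).data
+ >>> worker_weights = {torch.device("cpu"): server_weights.clone()}
+ >>> updater = MultiProcessedWeightUpdate(
+ ...     get_server_weights=lambda: server_weights,
+ ...     policy_weights=worker_weights,
+ ... )
+ >>> updater.push_weights()  # copies the server weights into each worker's buffer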
"""
def __init__(
self,
+ *,
get_server_weights: Callable[[], TensorDictBase] | None,
policy_weights: dict[torch.device, TensorDictBase],
):
@@ -310,13 +245,13 @@ def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(self._policy_weights)
def _sync_weights_with_worker(
- self, worker_id: int | torch.device, server_weights: TensorDictBase
- ) -> TensorDictBase:
+ self, *, worker_id: int | torch.device, server_weights: TensorDictBase | None
+ ) -> TensorDictBase | None:
if server_weights is None:
return
self._policy_weights[worker_id].data.update_(server_weights)
- def _get_server_weights(self) -> TensorDictBase:
+ def _get_server_weights(self) -> TensorDictBase | None:
- # The weights getter can be none if no mapping is required
+ # The weights getter can be None if the weights are passed directly to the updater
if self.weights_getter is None:
return
@@ -329,10 +264,10 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights
-class RayWeightUpdater(WeightUpdateSenderBase):
+class RayWeightUpdater(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across remote workers using Ray.
- The `RayWeightUpdateSender` class provides a mechanism for updating the weights of a policy
+ The `RayWeightUpdater` class provides a mechanism for updating the weights of a policy
across remote inference workers managed by Ray. It leverages Ray's distributed computing
capabilities to efficiently distribute policy weights to remote collectors.
This class is typically used in distributed data collectors where each remote worker requires
@@ -355,9 +290,9 @@ class RayWeightUpdater(WeightUpdateSenderBase):
.. note::
This class assumes that the server weights can be directly applied to the remote workers without
any additional processing. If your use case requires more complex weight mapping or synchronization
- logic, consider extending `WeightUpdateSenderBase` with a custom implementation.
+ logic, consider extending `WeightUpdaterBase` with a custom implementation.
- .. seealso:: :class:`~torchrl.collectors.WeightUpdateSenderBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
:class:`~torchrl.collectors.distributed.RayCollector`.
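+
+ Example:
+ A sketch, assuming an initialized Ray runtime, a ``policy_weights`` TensorDict and a list of
+ remote collector handles (argument names are indicative; see ``__init__`` for the exact
+ signature):
+
+ >>> updater = RayWeightUpdater(
+ ...     policy_weights=policy_weights,
+ ...     remote_collectors=remote_collectors,
+ ... )
+ >>> updater.push_weights()  # ray.put() the weights once, then sync every remote collector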
"""
@@ -376,17 +311,15 @@ def __init__(
def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(range(len(self.remote_collectors)))
- def _get_server_weights(self) -> TensorDictBase:
+ def _get_server_weights(self) -> Any:
import ray
return ray.put(self.policy_weights.data)
- def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ def _maybe_map_weights(self, server_weights: Any) -> Any:
return server_weights
- def _sync_weights_with_worker(
- self, worker_id: int, server_weights: TensorDictBase
- ) -> TensorDictBase:
+ def _sync_weights_with_worker(self, *, worker_id: int, server_weights: Any) -> Any:
torchrl_logger.info(f"syncing weights with worker {worker_id}")
c = self.remote_collectors[worker_id]
c.update_policy_weights_.remote(policy_weights=server_weights)