diff --git a/docs/source/_static/img/param-update.svg b/docs/source/_static/img/param-update.svg
new file mode 100644
index 00000000000..e09039a84c2
--- /dev/null
+++ b/docs/source/_static/img/param-update.svg
@@ -0,0 +1,186 @@
+
+
diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index f1ff3b1c8bf..5f02090305e 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.
Weight Synchronization in Distributed Environments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------------------------
+
In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
-Local and Remote Weight Updaters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sending and Receiving Model Weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
-- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
- the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
+- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
+ the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
+ instance to handle the "push" operation of the weights to the policy.
+ Users can extend this class to implement custom logic for synchronizing weights across a network of devices or processes.
+ For "regular", single node collectors, the :class:`~torchrl.collectors.VanillaWeightSender` will be used by default.
+- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for "pulling" the weights from a
+  remote parameter server. In many cases, it can be omitted, since weights are usually pushed to the workers and policy
+  through `WeightUpdateSenderBase` instances, in which case a call to `WeightUpdateReceiverBase.pull_weights` is a no-op.
+  `WeightUpdateReceiverBase` is particularly useful when training and inference occur on the same machine but on
different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
situations where the server decides when to update the worker policies).
-- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
- remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
- the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
- devices or processes.
+
+Each of these classes has a private `_maybe_map_weights` method that can be overridden to implement a weight
+transform or formatting logic.
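+
+As an illustration, here is a minimal sketch of a custom sender that overrides `_maybe_map_weights` to move the
+incoming weights to the device of the local policy copy before writing them. The `DeviceMappingSender` class below is
+hypothetical and assumes the policy copy lives on a single device; it only shows where such a transform would live.
+
+.. code-block:: python
+
+    from tensordict import TensorDictBase
+    from torchrl.collectors import WeightUpdateSenderBase
+
+    class DeviceMappingSender(WeightUpdateSenderBase):
+        """Hypothetical sender that maps incoming weights to the local policy's device."""
+
+        def __init__(self, policy_weights: TensorDictBase):
+            # TensorDict holding the local policy parameters, updated in-place.
+            self.policy_weights = policy_weights
+
+        def _maybe_map_weights(self, server_weights):
+            # Transform / format the incoming weights before they are pushed to the policy.
+            return server_weights.to(self.policy_weights.device)
+
+        def _sync_weights_with_worker(self, *, worker_id=None, server_weights=None):
+            # Single-node case: write the (transformed) weights into the local policy copy.
+            if server_weights is not None:
+                self.policy_weights.update_(server_weights)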
+
+The following figure showcases how a somewhat complex weight update scheme can be implemented using these primitives.
+
+.. figure:: /_static/img/param-update.svg
+
+   In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
+   parameter server is handled by the receiver of the main collector. The sender instance of the main collector then
+   sends the parameters to the workers and triggers a remote call to `update_policy_weights_` on each worker (or a
+   subset of them). Because this is a "push" operation, the receivers in the workers do not need to do anything. The
+   senders are responsible for writing the parameters to their copy of the policy.
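+
+As a rough illustration of the scheme above, a receiver can pull a state-dict from the parameter server before the
+sender pushes it to the workers. The sketch below is hypothetical: `ParamServerReceiver` and the `fetch_weights`
+callable stand in for whatever transport the parameter server actually uses (RPC, Ray, a shared store, ...).
+
+.. code-block:: python
+
+    from typing import Any, Callable
+
+    from tensordict import TensorDictBase
+    from torchrl.collectors import WeightUpdateReceiverBase
+
+    class ParamServerReceiver(WeightUpdateReceiverBase):
+        """Hypothetical receiver that pulls the latest weights from a parameter server."""
+
+        def __init__(self, fetch_weights: Callable[[], TensorDictBase]):
+            # `fetch_weights` is a placeholder for the parameter-server transport.
+            self.fetch_weights = fetch_weights
+
+        def _get_server_weights(self) -> Any:
+            # "Pull" operation: grab the latest weights (or a handle to them) from the server.
+            return self.fetch_weights()
+
+    # On the main collector, `collector.update_policy_weights_()` first pulls the weights through the
+    # receiver and then hands them to the sender, which pushes them to the workers.
+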
Extending the Updater Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -152,28 +169,12 @@ Default Implementations
~~~~~~~~~~~~~~~~~~~~~~~
For common scenarios, the API provides default implementations of these updaters, such as
-:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
-:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
+:class:`~torchrl.collectors.VanillaWeightReceiver`, :class:`~torchrl.collectors.VanillaWeightSender`,
+:class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, :class:`~torchrl.collectors.RayWeightUpdateSender`,
+:class:`~torchrl.collectors.RPCWeightUpdateSender`, and :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
distributed systems.
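+
+As a minimal sketch of the single-node defaults (assuming a Gym-backed environment and a trained `policy` module
+defined elsewhere), no configuration is needed: the default :class:`~torchrl.collectors.VanillaWeightSender` /
+:class:`~torchrl.collectors.VanillaWeightReceiver` pair is created automatically.
+
+.. code-block:: python
+
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.envs import GymEnv
+
+    collector = SyncDataCollector(
+        GymEnv("CartPole-v1"),
+        policy,  # assumed to be a policy module defined elsewhere
+        frames_per_batch=64,
+        total_frames=1024,
+    )
+    for data in collector:
+        ...  # train the policy on `data`
+        # Push the freshly trained weights to the collector's copy of the policy.
+        collector.update_policy_weights_()
+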
-Practical Considerations
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-When designing a system that leverages this API, consider the following:
-
-- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
- implementation accounts for potential delays and optimizes data transfer where possible.
-- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
- the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
- suboptimal policy performance.
-- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
- overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
-By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
-scenarios, ensuring that their policies remain up-to-date and performant.
-
.. currentmodule:: torchrl.collectors
.. autosummary::
@@ -182,7 +183,8 @@ scenarios, ensuring that their policies remain up-to-date and performant.
WeightUpdateReceiverBase
WeightUpdateSenderBase
- VanillaLocalWeightUpdater
+ VanillaWeightSender
+ VanillaWeightReceiver
MultiProcessedRemoteWeightUpdate
RayWeightUpdateSender
DistributedWeightUpdateSender
diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py
index 8e6c0d48fc5..02c61079e69 100644
--- a/torchrl/collectors/__init__.py
+++ b/torchrl/collectors/__init__.py
@@ -15,7 +15,8 @@
from .weight_update import (
MultiProcessedWeightUpdate,
RayWeightUpdater,
- VanillaWeightUpdater,
+ VanillaWeightReceiver,
+ VanillaWeightSender,
WeightUpdateReceiverBase,
WeightUpdateSenderBase,
)
@@ -24,7 +25,8 @@
"RandomPolicy",
"WeightUpdateReceiverBase",
"WeightUpdateSenderBase",
- "VanillaWeightUpdater",
+ "VanillaWeightSender",
+ "VanillaWeightReceiver",
"RayWeightUpdater",
"MultiProcessedWeightUpdate",
"aSyncDataCollector",
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 5485428a258..02431c0dff6 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -52,7 +52,8 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.collectors.weight_update import (
MultiProcessedWeightUpdate,
- VanillaWeightUpdater,
+ VanillaWeightReceiver,
+ VanillaWeightSender,
WeightUpdateReceiverBase,
WeightUpdateSenderBase,
)
@@ -299,9 +300,8 @@ def update_policy_weights_(
for the update. If not provided, the method will attempt to fetch the weights using the configured
weight updater.
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
- workers that need to be updated. This is relevant when using a remote weights updater, which must
- be specified during the data collector's initialization. If `worker_ids` is provided without a
- configured remote weights updater, a TypeError will be raised.
+ workers that need to be updated. This is relevant when the collector has more than one worker associated
+ with it.
Raises:
TypeError: If `worker_ids` is provided but no `weight_update_sender` is configured.
@@ -318,11 +318,8 @@ def update_policy_weights_(
"""
if self.weight_update_receiver is not None:
- self.weight_update_receiver(policy_weights, **kwargs)
- if self.weight_update_sender is not None:
- self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
- elif worker_ids is not None:
- raise TypeError("worker_ids was passed but weight_update_sender was None.")
+            policy_weights = self.weight_update_receiver(weights=policy_weights, **kwargs)
+ self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
def __iter__(self) -> Iterator[TensorDictBase]:
try:
@@ -539,7 +536,7 @@ class SyncDataCollector(DataCollectorBase):
Defaults to ``False``.
weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
or its subclass, responsible for updating the policy weights on the local inference worker.
- If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
+            If not provided, a :class:`~torchrl.collectors.VanillaWeightReceiver` will be used by default,
which directly fetches and applies the weights from the server.
Consider using a constructor if the updater needs to be serialized.
weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
@@ -893,10 +890,21 @@ def __init__(
self._frames = 0
self._iter = -1
- if weight_update_receiver is None:
- weight_update_receiver = VanillaWeightUpdater(
+ if weight_update_sender is None:
+ weight_update_sender = VanillaWeightSender(
weight_getter=self.get_weights_fn, policy_weights=self.policy_weights
)
+ elif not isinstance(weight_update_sender, WeightUpdateSenderBase):
+ raise TypeError(
+ "weight_update_sender must be a subclass of WeightUpdateSenderBase"
+ )
+
+ if weight_update_receiver is None:
+ weight_update_receiver = VanillaWeightReceiver()
+ elif not isinstance(weight_update_receiver, WeightUpdateReceiverBase):
+ raise TypeError(
+ "weight_update_receiver must be a subclass of WeightUpdateReceiverBase"
+ )
self.weight_update_receiver = weight_update_receiver
self.weight_update_sender = weight_update_sender
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 9aff1e018d0..959fc31da70 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -1053,7 +1053,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError
- def update_weights(
+ def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py
index 28c77daa56c..274f54b27db 100644
--- a/torchrl/collectors/distributed/rpc.py
+++ b/torchrl/collectors/distributed/rpc.py
@@ -887,7 +887,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError
- def update_weights(
+ def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
diff --git a/torchrl/collectors/llm.py b/torchrl/collectors/llm.py
index accf3ea80a2..9db4811d6ad 100644
--- a/torchrl/collectors/llm.py
+++ b/torchrl/collectors/llm.py
@@ -70,7 +70,7 @@ class LLMCollector(SyncDataCollector):
Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
or its subclass, responsible for updating the policy weights on the local inference worker.
- If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
+            If not provided, a :class:`~torchrl.collectors.VanillaWeightReceiver` will be used by default,
which directly fetches and applies the weights from the server.
Consider using a constructor if the updater needs to be serialized.
weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py
index ea92cb37fb2..09357578e0f 100644
--- a/torchrl/collectors/weight_update.py
+++ b/torchrl/collectors/weight_update.py
@@ -18,30 +18,33 @@
class WeightUpdateReceiverBase(metaclass=abc.ABCMeta):
- """A base class for updating local policy weights from a server.
+ """A base class for receiving policy weights from a server.
- This class provides an interface for downloading and updating the weights of a policy
- on a local inference worker. The update process is decentralized, meaning the inference
- worker is responsible for fetching the weights from the server.
+ This class provides an interface for downloading the weights of a policy
+ on a given node. It implements the "pull" operation in a weight update scheme.
+
+    Unlike the :class:`sender <torchrl.collectors.WeightUpdateSenderBase>` class, this class is optional
+ and should only be used when the collector needs to pull weights from a server.
To extend this class, implement the following abstract methods:
- `_get_server_weights`: Define how to retrieve the weights from the server.
- - `_get_local_weights`: Define how to access the current local weights.
+ Returns a state-dict or some handle to the server's weights to be consumed by the
+ weight sender (see :class:`~torchrl.collectors.WeightUpdateSenderBase`).
- `_maybe_map_weights`: Optionally transform the server weights to match the local weights.
+ By default, no transform is applied.
Attributes:
- policy (Policy, optional): The policy whose weights are to be updated.
- get_weights_from_policy (Callable, optional): A function to extract weights from the policy.
- get_weights_from_server (Callable, optional): A function to fetch weights from the server.
- weight_map_fn (Callable, optional): A function to map server weights to local weights.
- cache_policy_weights (bool): Whether to cache the policy weights locally.
+ collector: The collector (or any container) of the weight receiver. The collector is registered via
+ :meth:`~torchrl.collectors.WeightUpdateReceiverBase.register_collector`.
Methods:
- update_weights: Updates the local weights with the server weights.
-
+        pull_weights: Returns a state-dict or a handle to the server's weights to be consumed by the
+            :class:`weight sender <torchrl.collectors.WeightUpdateSenderBase>`.
+ register_collector: Registers the collector (or any container) in the receiver through a weakref.
+ This will be called automatically by the collector upon registration of the updater.
- .. seealso:: :class:`~torchrl.collectors.RemoteWeightsUpdaterBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdateSenderBase` and
:meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
"""
@@ -63,72 +66,80 @@ def register_collector(self, collector: DataCollectorBase): # noqa
@property
def collector(self) -> torchrl.collectors.DataCollectorBase: # noqa
- return self._collector_wr() if self._collector_wr is not None else None
+ """The collector or container of the receiver.
- @abstractmethod
- def _get_server_weights(self) -> TensorDictBase:
- ...
-
- @abstractmethod
- def _get_local_weights(self) -> TensorDictBase:
- ...
+ Returns `None` if the container is out-of-scope or not set.
+ """
+ return self._collector_wr() if self._collector_wr is not None else None
- @abstractmethod
- def _maybe_map_weights(
- self, server_weights: TensorDictBase, local_weights: TensorDictBase
- ) -> TensorDictBase:
- ...
+ def _get_server_weights(self) -> Any:
+ """An optional method to gather weights from the server."""
+ raise NotImplementedError
- def _update_local_weights(
- self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
- ) -> TensorDictBase:
- local_weights.update_(mapped_weights)
+ def _maybe_map_weights(self, server_weights: Any, *args, **kwargs) -> Any:
+ """A method to transform the server weights to match the local weights."""
+ return server_weights
def __call__(
self,
- weights: TensorDictBase | None = None,
+ *,
+ weights: Any = None,
):
- return self.update_weights(weights=weights)
+ """A proxy to :meth:`~.pull_weights`."""
+ return self.pull_weights(weights=weights)
- def update_weights(self, weights: TensorDictBase | None = None) -> TensorDictBase:
+ def pull_weights(self, *, weights: Any = None) -> Any:
+ """Pull weights from the server.
+
+ Keyword Args:
+ weights (Any, optional): The weights (handle). If not provided, weights are
+ retrieved from the server via :meth:`~._get_server_weights`.
+
+ Returns:
+ The weights or a handle to them.
+
+ """
if weights is None:
# get server weights (source)
server_weights = self._get_server_weights()
else:
server_weights = weights
- # Get local weights
- local_weights = self._get_local_weights()
# Optionally map the weights
- mapped_weights = self._maybe_map_weights(server_weights, local_weights)
+ mapped_weights = self._maybe_map_weights(server_weights)
- # Update the weights
- self._update_local_weights(local_weights, mapped_weights)
+ return mapped_weights
class WeightUpdateSenderBase(metaclass=abc.ABCMeta):
"""A base class for updating remote policy weights on inference workers.
- This class provides an interface for uploading and synchronizing the weights of a policy
- across remote inference workers. The update process is centralized, meaning the server
- is responsible for distributing the weights to the inference nodes.
+ The weight sender is the central piece of the weight update scheme:
+
+ - In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as
+ updating a state-dict, or more complex if an inference server is being used.
+ - In server collector nodes, it is responsible for sending the weights ("push") to the leaf collectors.
+
+ In a collector, the sender is always called after the receiver.
To extend this class, implement the following abstract methods:
+ - `_get_server_weights`: Define how to retrieve the weights from the server if they are not passed to
+ the updater directly.
- `_sync_weights_with_worker`: Define how to synchronize weights with a specific worker.
- - `_get_server_weights`: Define how to retrieve the weights from the server.
- `_maybe_map_weights`: Optionally transform the server weights before distribution.
- `all_worker_ids`: Provide a list of all worker identifiers.
Attributes:
- policy (Policy, optional): The policy whose weights are to be updated.
+        collector: The collector (or any container) of the weight sender. The collector is registered via
+            :meth:`~torchrl.collectors.WeightUpdateSenderBase.register_collector`.
Methods:
update_weights: Updates the weights on specified or all remote workers.
- register_collector: Registers a collector. This should be called automatically by the collector
- upon registration of the updater.
+        register_collector: Registers the collector (or any container) in the sender through a weakref.
+ This will be called automatically by the collector upon registration of the updater.
- .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
+ .. seealso:: :class:`~torchrl.collectors.WeightUpdateReceiverBase` and
:meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
"""
@@ -150,41 +161,65 @@ def register_collector(self, collector: DataCollectorBase): # noqa
@property
def collector(self) -> torch.collector.DataCollectorBase: # noqa
+ """The collector or container of the receiver.
+
+ Returns `None` if the container is out-of-scope or not set.
+ """
return self._collector_wr() if self._collector_wr is not None else None
@abstractmethod
def _sync_weights_with_worker(
- self, worker_id: int | torch.device, server_weights: TensorDictBase
- ) -> TensorDictBase:
- ...
+ self, *, worker_id: int | torch.device | None = None, server_weights: Any
+ ) -> Any:
+ """An abstract method that updates the weights on specified workers.
- @abstractmethod
- def _get_server_weights(self) -> TensorDictBase:
+ The worker id can be `None` if there are no workers associated with the sender.
+ """
...
- @abstractmethod
- def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
- ...
+ def _get_server_weights(self) -> Any:
+ """An optional method to gather weights from the server.
- @abstractmethod
- def all_worker_ids(self) -> list[int] | list[torch.device]:
- ...
+ This method is called only if the weights (handle) are not passed directly to the update method.
+ """
+ raise NotImplementedError
+
+ def _maybe_map_weights(self, server_weights: Any) -> Any:
+ """Optionally transforms the server weights to match the local weights."""
+ return server_weights
+
+ def all_worker_ids(self) -> list[int] | list[torch.device] | None:
+ """Returns a list of all worker identifiers or `None` if there are no workers associated."""
+ return
def _skip_update(self, worker_id: int | torch.device) -> bool:
+ """A method to determine if a worker should be skipped."""
return False
def __call__(
self,
- weights: TensorDictBase | None = None,
+ weights: Any = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
):
- return self.update_weights(weights=weights, worker_ids=worker_ids)
+ """A proxy to :meth:`~.push_weights`."""
+ return self.push_weights(weights=weights, worker_ids=worker_ids)
- def update_weights(
+ def push_weights(
self,
- weights: TensorDictBase | None = None,
+ *,
+ weights: Any | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
):
+ """Updates the weights of the policy, or on specified / all remote workers.
+
+ Args:
+ weights (Any): The source weights to push to the policy / workers.
+ worker_ids (torch.device | int | list[int] | list[torch.device] | None = None): an optional list of
+ workers to update.
+
+ Returns: nothing.
+
+ """
if weights is None:
# Get the weights on server (local)
server_weights = self._get_server_weights()
@@ -198,47 +233,49 @@ def update_weights(
worker_ids = [worker_ids]
elif worker_ids is None:
worker_ids = self.all_worker_ids()
+ if worker_ids is None:
+ self._sync_weights_with_worker(server_weights=server_weights)
+ return
for worker in worker_ids:
if self._skip_update(worker):
continue
- self._sync_weights_with_worker(worker, server_weights)
+ self._sync_weights_with_worker(
+ worker_id=worker, server_weights=server_weights
+ )
+
+
+class VanillaWeightReceiver(WeightUpdateReceiverBase):
+ """A simple implementation of :class:`~torchrl.collectors.WeightUpdateReceiverBase` for updating local policy weights."""
+
+ def _get_server_weights(self) -> Any:
+ return None
# Specialized classes
-class VanillaWeightUpdater(WeightUpdateReceiverBase):
- """A simple implementation of `WeightUpdateReceiverBase` for updating local policy weights.
+class VanillaWeightSender(WeightUpdateSenderBase):
+ """A simple implementation of :class:`~torchrl.collectors.WeightUpdateSenderBase` for updating local policy weights.
- The `VanillaLocalWeightUpdater` class provides a basic mechanism for updating the weights
+ The `VanillaWeightSender` class provides a basic mechanism for updating the weights
of a local policy by directly fetching them from a specified source. It is typically used
in scenarios where the weight update logic is straightforward and does not require any
complex mapping or transformation.
- This class is used by default in the `SyncDataCollector` when no custom local weights updater
+ This class is used by default in the `SyncDataCollector` when no custom weight sender
is provided.
- Args:
- weight_getter (Callable[[], TensorDictBase]): A callable that returns the latest policy
- weights from the server or another source.
- policy_weights (TensorDictBase): The current weights of the local policy that need to be updated.
-
- Methods:
- _get_server_weights: Retrieves the latest weights from the specified source.
- _get_local_weights: Accesses the current local policy weights.
- _map_weights: Directly maps server weights to local weights without transformation.
- _maybe_map_weights: Optionally maps server weights to local weights (no-op in this implementation).
- _update_local_weights: Updates the local policy weights with the mapped weights.
-
- .. note::
- This class assumes that the server weights can be directly applied to the local policy
- without any additional processing. If your use case requires more complex weight mapping,
- consider extending `WeightUpdateReceiverBase` with a custom implementation.
-
.. seealso:: :class:`~torchrl.collectors.WeightUpdateReceiverBase` and :class:`~torchrl.collectors.SyncDataCollector`.
+
+ Keyword Args:
+ weight_getter (Callable[[], TensorDictBase], optional): a callable that returns the weights from the server.
+            If not provided, the weights must be passed to :meth:`~.push_weights` directly.
+ policy_weights (TensorDictBase): a TensorDictBase containing the policy weights to be updated
+ in-place.
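+
+    Examples:
+        A minimal, illustrative sketch; `policy` and `new_weights` are assumed to be defined elsewhere
+        (e.g. `new_weights` being a TensorDict of freshly trained parameters).
+
+        >>> from tensordict import TensorDict
+        >>> weights = TensorDict.from_module(policy)
+        >>> sender = VanillaWeightSender(policy_weights=weights)
+        >>> sender(weights=new_weights)  # writes `new_weights` into `weights` in-place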
"""
def __init__(
self,
- weight_getter: Callable[[], TensorDictBase],
+ *,
+ weight_getter: Callable[[], TensorDictBase] | None = None,
policy_weights: TensorDictBase,
):
self.weight_getter = weight_getter
@@ -253,17 +290,15 @@ def _get_local_weights(self) -> TensorDictBase:
def _map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights
- def _maybe_map_weights(
- self, server_weights: TensorDictBase, local_weights: TensorDictBase
- ) -> TensorDictBase:
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights
- def _update_local_weights(
- self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
+ def _sync_weights_with_worker(
+ self, *, worker_id: None = None, server_weights: TensorDictBase
) -> TensorDictBase:
- if local_weights is None or mapped_weights is None:
+ if server_weights is None:
return
- local_weights.update_(mapped_weights)
+ self.policy_weights.update_(server_weights)
class MultiProcessedWeightUpdate(WeightUpdateSenderBase):
@@ -276,18 +311,12 @@ class MultiProcessedWeightUpdate(WeightUpdateSenderBase):
This class is typically used in multiprocessed data collectors where each process or device
requires an up-to-date copy of the policy weights.
- Args:
+ Keyword Args:
get_server_weights (Callable[[], TensorDictBase] | None): A callable that retrieves the
latest policy weights from the server or another centralized source.
policy_weights (Dict[torch.device, TensorDictBase]): A dictionary mapping each device or
process to its current policy weights, which will be updated.
- Methods:
- all_worker_ids: Returns a list of all worker identifiers (devices or processes).
- _sync_weights_with_worker: Synchronizes the server weights with a specific worker.
- _get_server_weights: Retrieves the latest weights from the server.
- _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation).
-
.. note::
This class assumes that the server weights can be directly applied to the workers without
any additional processing. If your use case requires more complex weight mapping or synchronization
@@ -300,6 +329,7 @@ class MultiProcessedWeightUpdate(WeightUpdateSenderBase):
def __init__(
self,
+ *,
get_server_weights: Callable[[], TensorDictBase] | None,
policy_weights: dict[torch.device, TensorDictBase],
):
@@ -310,13 +340,13 @@ def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(self._policy_weights)
def _sync_weights_with_worker(
- self, worker_id: int | torch.device, server_weights: TensorDictBase
- ) -> TensorDictBase:
+ self, worker_id: int | torch.device, server_weights: TensorDictBase | None
+ ) -> TensorDictBase | None:
if server_weights is None:
return
self._policy_weights[worker_id].data.update_(server_weights)
- def _get_server_weights(self) -> TensorDictBase:
+ def _get_server_weights(self) -> TensorDictBase | None:
# The weights getter can be none if no mapping is required
if self.weights_getter is None:
return
@@ -376,17 +406,15 @@ def __init__(
def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(range(len(self.remote_collectors)))
- def _get_server_weights(self) -> TensorDictBase:
+ def _get_server_weights(self) -> Any:
import ray
return ray.put(self.policy_weights.data)
- def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ def _maybe_map_weights(self, server_weights: Any) -> Any:
return server_weights
- def _sync_weights_with_worker(
- self, worker_id: int, server_weights: TensorDictBase
- ) -> TensorDictBase:
+ def _sync_weights_with_worker(self, worker_id: int, server_weights: Any) -> Any:
torchrl_logger.info(f"syncing weights with worker {worker_id}")
c = self.remote_collectors[worker_id]
c.update_policy_weights_.remote(policy_weights=server_weights)