[Refactor] Refactor the weight update logic #2914


Open

wants to merge 3 commits into base: gh/vmoens/130/base

Changes from 2 commits
186 changes: 186 additions & 0 deletions docs/source/_static/img/param-update.svg
(New SVG figure; not rendered in the diff view.)
60 changes: 31 additions & 29 deletions docs/source/reference/collectors.rst
@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.

Weight Synchronization in Distributed Environments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------------------------

@mikaylagawarecki Apr 24, 2025

I read the diagram above as

  1. CollectorServer: main thread of RayCollector
  2. Collector Worker {i}, remote DataCollector

If this read is correct, in my mind, it might sometimes make sense to have the receiver on the collector worker rather than the collector server
e.g., if the number of remote workers is sufficiently high, a collector worker might not be colocated with the collector server; in that case it might not make sense to pass the weights "two hops" to get to the worker.

Separate question -- from the diagram it looks like the collector server chooses when to pull from the param server and then "forcefully pushes" to all the workers at once. Is this design intentional? (e.g. Is the purpose of this to batch up workers to different collector servers and update them in batches?)


In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.

Local and Remote Weight Updaters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sending and receiving model weights with WeightUpdaters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.

- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
instance to handle the "push" operation of the weights to the policy.

@mikaylagawarecki Apr 24, 2025

I think "push/pull" and "sender/receiver" are confusing 🫤 In particular, for me the Receiver == "Puller" part is tough to wrap my head around.

Pull architecture: the client sends the request, and the server responds accordingly
Push architecture: the server pushes data to clients as updates become available

The confusion for me is that I think of sender --> receiver as "sender actively pushes, receiver passively receives". Hence receiver == puller is not intuitive

Contributor Author

Got it.
In this context I'm starting to think that having 2 separate classes will always be confusing, so perhaps we should just have one that can be customized at will.
In every case I've been dealing with so far, it never occurred that I could write senders and receivers that would compose freely, so that tells me that making a perfectly composable API may be an illusion.
I'm myself a bit confused about what should live within each of these classes, to be honest...
I'll refactor this to have a single Updater class that gives a somewhat unopinionated implementation of the update functionality!

Users can extend this class to implement custom logic for synchronizing weights across a network of devices or processes.
For "regular", single node collectors, the :class:`~torchrl.collectors.VanillaWeightSender` will be used by default.
- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for "pulling" the weights from a
distant parameter server. In many cases, it can be discarded as weights are forcibly pushed to the workers and policy
through `WeightUpdateSenderBase` instances, and a call to `WeightUpdateReceiverBase.update_weights` will be a no-op.
`WeightUpdateReceiverBase` is particularly useful when the training and inference occur on the same machine but on
different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
situations where the server decides when to update the worker policies).
- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
devices or processes.

Each of these classes exposes a private `_maybe_map_weights` method that can be overridden to implement a weight
transform or formatting logic.
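
As an illustration (not part of this diff), a custom sender could override these hooks roughly as
sketched below. The hook names (``_maybe_map_weights``, ``all_worker_ids``, ``push_weights``) follow
the ones used in this PR, but the constructor arguments, the in-process ``worker_policies`` handles
and the exact base-class signatures are assumptions.

.. code-block:: python

    from __future__ import annotations

    import torch
    from tensordict import TensorDictBase
    from torchrl.collectors import WeightUpdateSenderBase


    class DeviceMappingWeightSender(WeightUpdateSenderBase):
        """Sketch: casts weights to an inference device, then writes them to each worker's policy copy."""

        def __init__(self, worker_policies: list[torch.nn.Module], device: torch.device):
            super().__init__()
            self.worker_policies = worker_policies  # hypothetical in-process handles
            self.device = device

        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
            # Transform/formatting hook: move everything to the inference device.
            return server_weights.to(self.device)

        def all_worker_ids(self) -> list[int]:
            return list(range(len(self.worker_policies)))

        def push_weights(self, weights: TensorDictBase | None = None, worker_ids=None, **kwargs):
            # We assume `weights` is provided by the caller here; a real sender could
            # also fall back to a configured weight getter.
            weights = self._maybe_map_weights(weights)
            if worker_ids is None:
                worker_ids = self.all_worker_ids()
            elif isinstance(worker_ids, int):
                worker_ids = [worker_ids]
            for i in worker_ids:
                # In-process write; a real sender would do an IPC/RPC/collective transfer here.
                weights.to_module(self.worker_policies[i])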

The following figure showcases how a somewhat complex weight update scheme can be implemented using these primitives.

.. figure:: /_static/img/param-update.svg

In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the


why do you envision this to hold various copies rather than one?

parameter server is handled by the main collector receiver. The main collector server sender instance sends the


main collector server

Is it accurate to think of this as the main thread in RayCollector?

parameters to the workers and triggers a remote call to `update_policy_weights_` in each worker, or in a subset of them.
Because this is a "push" operation, the receivers in the workers do not need to do anything. The senders
are responsible for writing the parameters to their copy of the policy.
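
As a complementary sketch (again, not part of this diff), the "pull" side could be a receiver that
fetches weights from a parameter server and applies them to the local policy. The ``param_server``
handle and its ``get_weights()`` call are hypothetical, and the base-class signatures are assumptions;
only the method names (``update_weights``, ``_maybe_map_weights``) follow the text above.

.. code-block:: python

    from __future__ import annotations

    from tensordict import TensorDictBase
    from torchrl.collectors import WeightUpdateReceiverBase


    class ParamServerWeightReceiver(WeightUpdateReceiverBase):
        """Sketch: pulls weights from a (hypothetical) parameter server and applies them locally."""

        def __init__(self, param_server, policy):
            super().__init__()
            # Hypothetical handle; get_weights() is assumed to return a TensorDict
            # structured like TensorDict.from_module(policy).
            self.param_server = param_server
            self.policy = policy

        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
            # Formatting hook: e.g. rename keys or cast dtypes here.
            return server_weights

        def update_weights(self, weights: TensorDictBase | None = None, **kwargs) -> TensorDictBase:
            if weights is None:
                weights = self.param_server.get_weights()  # the actual "pull"
            weights = self._maybe_map_weights(weights)
            weights.to_module(self.policy)  # write into the local policy copy
            return weights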

Extending the Updater Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -152,28 +169,12 @@ Default Implementations
~~~~~~~~~~~~~~~~~~~~~~~

For common scenarios, the API provides default implementations of these updaters, such as
:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
:class:`~torchrl.collectors.VanillaWeightReceiver`, :class:`~torchrl.collectors.VanillaWeightSender`,
:class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, :class:`~torchrl.collectors.RayWeightUpdateSender`,
:class:`~torchrl.collectors.RPCWeightUpdateSender`, and :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
distributed systems.
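
As a usage-level sketch of these defaults (the environment choice, constructor arguments and
placeholder objective below are illustrative, not a real RL loss), a single-process training loop
steps the optimizer on its copy of the policy and then calls ``update_policy_weights_()``, letting
the default ``VanillaWeightSender`` write the updated weights into the collector's policy:

.. code-block:: python

    import torch
    from torch import nn
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.modules import Actor

    env = GymEnv("Pendulum-v1")
    policy = Actor(
        nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1])
    )
    collector = SyncDataCollector(env, policy, frames_per_batch=64, total_frames=256)
    optim = torch.optim.Adam(policy.parameters(), lr=1e-3)

    for batch in collector:
        # Placeholder objective: re-run the policy on a clone of the collected batch
        # so that gradients flow (a real loop would use a proper RL loss).
        loss = policy(batch.clone())["action"].pow(2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad()
        # Propagate the freshly updated weights to the collector's policy copy;
        # with the defaults described above this is the sender's "push".
        collector.update_policy_weights_()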

Practical Considerations
~~~~~~~~~~~~~~~~~~~~~~~~

When designing a system that leverages this API, consider the following:

- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
implementation accounts for potential delays and optimizes data transfer where possible.
- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
suboptimal policy performance.
- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.

By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
scenarios, ensuring that their policies remain up-to-date and performant.

.. currentmodule:: torchrl.collectors

.. autosummary::
@@ -182,7 +183,8 @@ scenarios, ensuring that their policies remain up-to-date and performant.

WeightUpdateReceiverBase
WeightUpdateSenderBase
VanillaLocalWeightUpdater
VanillaWeightSender
VanillaWeightReceiver
MultiProcessedRemoteWeightUpdate
RayWeightUpdateSender
DistributedWeightUpdateSender
6 changes: 4 additions & 2 deletions torchrl/collectors/__init__.py
@@ -15,7 +15,8 @@
from .weight_update import (
MultiProcessedWeightUpdate,
RayWeightUpdater,
VanillaWeightUpdater,
VanillaWeightReceiver,
VanillaWeightSender,
WeightUpdateReceiverBase,
WeightUpdateSenderBase,
)
@@ -24,7 +25,8 @@
"RandomPolicy",
"WeightUpdateReceiverBase",
"WeightUpdateSenderBase",
"VanillaWeightUpdater",
"VanillaWeightSender",
"VanillaWeightReceiver",
"RayWeightUpdater",
"MultiProcessedWeightUpdate",
"aSyncDataCollector",
32 changes: 20 additions & 12 deletions torchrl/collectors/collectors.py
@@ -52,7 +52,8 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.collectors.weight_update import (
MultiProcessedWeightUpdate,
VanillaWeightUpdater,
VanillaWeightReceiver,
VanillaWeightSender,
WeightUpdateReceiverBase,
WeightUpdateSenderBase,
)
@@ -299,9 +300,8 @@ def update_policy_weights_(
for the update. If not provided, the method will attempt to fetch the weights using the configured
weight updater.
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
workers that need to be updated. This is relevant when using a remote weights updater, which must
be specified during the data collector's initialization. If `worker_ids` is provided without a
configured remote weights updater, a TypeError will be raised.
workers that need to be updated. This is relevant when the collector has more than one worker associated
with it.

Raises:
TypeError: If `worker_ids` is provided but no `weight_update_sender` is configured.
@@ -318,11 +318,8 @@

"""
if self.weight_update_receiver is not None:
self.weight_update_receiver(policy_weights, **kwargs)
if self.weight_update_sender is not None:
self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
elif worker_ids is not None:
raise TypeError("worker_ids was passed but weight_update_sender was None.")
policy_weights = self.weight_update_receiver(policy_weights, **kwargs)
self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)

def __iter__(self) -> Iterator[TensorDictBase]:
try:
@@ -539,7 +536,7 @@ class SyncDataCollector(DataCollectorBase):
Defaults to ``False``.
weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
or its subclass, responsible for updating the policy weights on the local inference worker.
If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
If not provided, a :class:`~torchrl.collectors.VanillaWeightReceiver` will be used by default,
which directly fetches and applies the weights from the server.
Consider using a constructor if the updater needs to be serialized.
weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
@@ -893,10 +890,21 @@ def __init__(
self._frames = 0
self._iter = -1

if weight_update_receiver is None:
weight_update_receiver = VanillaWeightUpdater(
if weight_update_sender is None:
weight_update_sender = VanillaWeightSender(
weight_getter=self.get_weights_fn, policy_weights=self.policy_weights
)
elif not isinstance(weight_update_sender, WeightUpdateSenderBase):
raise TypeError(
"weight_update_sender must be a subclass of WeightUpdateSenderBase"
)

if weight_update_receiver is None:
weight_update_receiver = VanillaWeightReceiver()
elif not isinstance(weight_update_receiver, WeightUpdateReceiverBase):
raise TypeError(
"weight_update_receiver must be a subclass of WeightUpdateReceiverBase"
)

self.weight_update_receiver = weight_update_receiver
self.weight_update_sender = weight_update_sender
2 changes: 1 addition & 1 deletion torchrl/collectors/distributed/generic.py
@@ -1053,7 +1053,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError

def update_weights(
def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
2 changes: 1 addition & 1 deletion torchrl/collectors/distributed/rpc.py
@@ -887,7 +887,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
def all_worker_ids(self) -> list[int] | list[torch.device]:
raise NotImplementedError

def update_weights(
def push_weights(
self,
weights: TensorDictBase | None = None,
worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
2 changes: 1 addition & 1 deletion torchrl/collectors/llm.py
@@ -70,7 +70,7 @@ class LLMCollector(SyncDataCollector):
Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
or its subclass, responsible for updating the policy weights on the local inference worker.
If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
If not provided, a :class:`~torchrl.collectors.VanillaWeightReceiver` will be used by default,
which directly fetches and applies the weights from the server.
Consider using a constructor if the updater needs to be serialized.
weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`