[Refactor] Refactor the weight update logic #2914

@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
 Policy copy decision tree in Collectors.

 Weight Synchronization in Distributed Environments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------------------------

 In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
 latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
 mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.

-Local and Remote Weight Updaters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sending and receiving model weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
 and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
 implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.

-- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
-  the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
-  different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
-- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
-  remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
-  the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
-  devices or processes.
+- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
+  the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
+  instance to handle the "push" operation of the weights to the policy.
+  Users can extend this class to implement custom logic for synchronizing weights across a network of devices or processes.
+  For "regular", single-node collectors, the :class:`~torchrl.collectors.VanillaWeightSender` will be used by default.
+- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for "pulling" the weights from a
+  distant parameter server. In many cases, it can be discarded, as weights are forcibly pushed to the workers and policy
+  through `WeightUpdateSenderBase` instances, and a call to `WeightUpdateReceiverBase.update_weights` will be a no-op.
+  `WeightUpdateReceiverBase` is particularly useful when the training and inference occur on the same machine but on
+  different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
+  It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
+  situations where the server decides when to update the worker policies).

    Review comment: I think "push/pull" and "sender/receiver" are confusing 🫤 In particular, the
    Receiver == "Puller" part is tough to wrap my head around (pull architecture: the client sends
    the request, and the server responds accordingly). The confusion is that I think of
    sender --> receiver as "sender actively pushes, receiver passively receives"; hence
    receiver == puller is not intuitive.

    Reply: Got it
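
For illustration, subclassing might look roughly like the sketch below. Only `WeightUpdateSenderBase`,
`WeightUpdateReceiverBase`, and the `_maybe_map_weights` hook are named in this diff; the
`register_worker`/`push_weights` methods and the plain state-dict weight format are assumptions, not
the PR's API.

.. code-block:: python

    # Minimal sketch of a custom sender. Only ``WeightUpdateSenderBase`` and the
    # ``_maybe_map_weights`` hook appear in this diff; ``register_worker``,
    # ``push_weights`` and the plain state-dict format are illustrative assumptions.
    from torchrl.collectors import WeightUpdateSenderBase


    class BroadcastWeightSender(WeightUpdateSenderBase):
        """Pushes the trainer's weights into every registered policy copy."""

        def __init__(self):
            super().__init__()
            self._workers = []  # handles to policy copies (assumed registry)

        def register_worker(self, worker):
            self._workers.append(worker)

        def _maybe_map_weights(self, weights):
            # Formatting step: detach and move to CPU before shipping.
            return {k: v.detach().cpu() for k, v in weights.items()}

        def push_weights(self, weights):
            # The "push" operation: write the parameters into each policy copy.
            weights = self._maybe_map_weights(weights)
            for worker in self._workers:
                worker.load_state_dict(weights)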

 Each of these classes has a private `_maybe_map_weights` method that can be overwritten, where a weight transform or
 formatting logic can be implemented.
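
A receiver that serves inference on a different device could override this hook; a minimal sketch
follows, where the dict-of-tensors weight format and the `device` argument are assumptions.

.. code-block:: python

    # Sketch: overriding the private ``_maybe_map_weights`` hook. The
    # dict-of-tensors weight format and the target device are assumptions.
    from torchrl.collectors import WeightUpdateReceiverBase


    class ToDeviceWeightReceiver(WeightUpdateReceiverBase):
        def __init__(self, device="cuda:0"):
            super().__init__()
            self.device = device

        def _maybe_map_weights(self, weights):
            # Place incoming parameters on the inference device.
            return {k: v.to(self.device, non_blocking=True) for k, v in weights.items()}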

+The following figure showcases how a somewhat complex weight update scheme can be implemented using these primitives.
+
+.. figure:: /_static/img/param-update.svg
+
+In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
+parameter server is handled by the main collector receiver. The main collector server sender instance sends the
+parameters to the workers and triggers a remote call to `update_policy_weights_` in each or some workers.
+Because this is a "push" operation, the receivers in the workers do not need to do anything. The senders
+are responsible for writing the parameters to their copy of the policy.

    Review comment: why do you envision this to hold various copies rather than one?

    Review comment: Is it accurate to think of this as the main thread in …
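
The control flow of the figure could be approximated as in the sketch below; the `param_server`
actor, its `get_weights` method, and passing weights directly to `update_policy_weights_` are
assumptions about the surrounding setup rather than part of this PR.

.. code-block:: python

    # Sketch of the push-based flow in the figure. The ``param_server`` actor,
    # its ``get_weights`` method, and passing weights to
    # ``update_policy_weights_`` are assumptions about the surrounding setup.
    import ray

    def sync_workers(param_server, main_collector, workers):
        # 1. The main collector's receiver "pulls" the current weights once.
        weights = ray.get(param_server.get_weights.remote())
        # 2. Its sender writes them into the local policy copy...
        main_collector.update_policy_weights_(weights)
        # 3. ...and triggers the remote update on each (or some) worker(s).
        #    Worker-side receivers do nothing: this is a push, and each worker's
        #    sender writes the parameters into its own copy of the policy.
        ray.get([w.update_policy_weights_.remote(weights) for w in workers])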

 Extending the Updater Classes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -152,28 +169,12 @@ Default Implementations
 ~~~~~~~~~~~~~~~~~~~~~~~

 For common scenarios, the API provides default implementations of these updaters, such as
-:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
-:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
+:class:`~torchrl.collectors.VanillaWeightReceiver`, :class:`~torchrl.collectors.VanillaWeightSender`,
+:class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, :class:`~torchrl.collectors.RayWeightUpdateSender`,
+:class:`~torchrl.collectors.RPCWeightUpdateSender`, and :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
 These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
 distributed systems.
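
As an illustration, a single-node setup might wire the vanilla defaults in explicitly. The
`weight_update_sender`/`weight_update_receiver` keyword names and the no-argument constructors
below are assumptions; in practice the collector is expected to install the vanilla classes by
default.

.. code-block:: python

    # Sketch: explicitly selecting the default updaters on a single-node
    # collector. The ``weight_update_sender``/``weight_update_receiver`` keyword
    # names and the no-argument constructors are assumptions.
    from torchrl.collectors import (
        SyncDataCollector,
        VanillaWeightReceiver,
        VanillaWeightSender,
    )

    collector = SyncDataCollector(
        create_env_fn=make_env,  # environment factory, assumed defined elsewhere
        policy=policy,           # policy module, assumed defined elsewhere
        frames_per_batch=256,
        total_frames=10_000,
        weight_update_sender=VanillaWeightSender(),
        weight_update_receiver=VanillaWeightReceiver(),
    )
    collector.update_policy_weights_()  # pushes the latest weights into the policy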

-Practical Considerations
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-When designing a system that leverages this API, consider the following:
-
-- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
-  implementation accounts for potential delays and optimizes data transfer where possible.
-- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
-  the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
-  suboptimal policy performance.
-- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
-  overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
 By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
 scenarios, ensuring that their policies remain up-to-date and performant.

 .. currentmodule:: torchrl.collectors

 .. autosummary::

@@ -182,7 +183,8 @@ scenarios, ensuring that their policies remain up-to-date and performant.

     WeightUpdateReceiverBase
     WeightUpdateSenderBase
-    VanillaLocalWeightUpdater
+    VanillaWeightSender
+    VanillaWeightReceiver
     MultiProcessedRemoteWeightUpdate
     RayWeightUpdateSender
     DistributedWeightUpdateSender

Review comment: I read the diagram above as:

- `CollectorServer`: main thread of `RayCollector`
- `Collector Worker {i}`: remote `DataCollector`

If this read is correct, it might sometimes make sense to have the receiver on the collector worker
rather than the collector server. E.g., if the number of remote workers is sufficiently high, the
collector worker might not be colocated with the collector server, in which case it might not make
sense to pass the weights "two hops" to get to the worker.

Separate question -- from the diagram it looks like the collector server chooses when to pull from
the param server and then "forcefully pushes" to all the workers at once. Is this design
intentional? (E.g., is the purpose to batch up workers to different collector servers and update
them in batches?)