@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.
Weight Synchronization in Distributed Environments
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ --------------------------------------------------
+
In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
- Local and Remote Weight Updaters
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ Sending and receiving model weights with WeightUpdaters
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
- - :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
- the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
+ - :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
+ the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
+ instance to handle the "push" operation of the weights to the policy.
+ Users can extend this class to implement custom logic for synchronizing weights across a network of devices or processes.
+ For "regular", single node collectors, the :class:`~torchrl.collectors.VanillaWeightSender` will be used by default.
+ - :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for "pulling" the weights from a
+ distant parameter server. In many cases, it can be discarded as weights are forcibly pushed to the workers and policy
+ through `WeightUpdateSenderBase` instances, and a call to `WeightUpdateReceiverBase.update_weights` will be a no-op.
+ `WeightUpdateReceiverBase` is particularly useful when the training and inference occur on the same machine but on
different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
situations where the server decides when to update the worker policies).
- - :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
- remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
- the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
- devices or processes.
+
+ Each of these classes has a private `_maybe_map_weights` method that can be overridden, where a weight transform or
+ formatting logic can be implemented.
+
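+ As an illustration, the following sketch overrides `_maybe_map_weights` in a custom sender to cast the weights to
+ half precision before they are written to the policy. The signature used for `_maybe_map_weights`, the assumption
+ that the weights are passed as a TensorDict, and the omission of the remaining abstract methods of
+ :class:`~torchrl.collectors.WeightUpdateSenderBase` are simplifications; refer to the base class for the exact interface.
+
+ .. code-block:: python
+
+     from torchrl.collectors import WeightUpdateSenderBase
+
+     class HalfPrecisionWeightSender(WeightUpdateSenderBase):
+         """Sketch of a sender that casts weights to float16 before pushing them."""
+
+         def _maybe_map_weights(self, weights):
+             # ``weights`` is assumed to be a TensorDict-like container of parameters;
+             # cast every tensor to half precision before it is sent to the policy.
+             return weights.apply(lambda t: t.half())
+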
+ The following figure showcases how a somewhat complex weight update scheme can be implemented using these primitives.
+
+ .. figure:: /_static/img/param-update.svg
+
+ In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
+ parameter server is handled by the main collector receiver. The sender instance of the main collector (the server)
+ sends the parameters to the workers and triggers a remote call to `update_policy_weights_` on each or some of the workers.
+ Because this is a "push" operation, the receivers in the workers do not need to do anything. The senders
+ are responsible for writing the parameters to their copy of the policy.
Extending the Updater Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -152,28 +169,12 @@ Default Implementations
~~~~~~~~~~~~~~~~~~~~~~~
For common scenarios, the API provides default implementations of these updaters, such as
- :class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
- :class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
- :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
+ :class:`~torchrl.collectors.VanillaWeightReceiver`, :class:`~torchrl.collectors.VanillaWeightSender`,
+ :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, :class:`~torchrl.collectors.RayWeightUpdateSender`,
+ :class:`~torchrl.collectors.RPCWeightUpdateSender`, and :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
distributed systems.
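+
+ The sketch below is a minimal single-process example (it assumes gymnasium's Pendulum-v1 is available and elides the
+ training step): the default vanilla sender is used implicitly, and calling `update_policy_weights_` after each
+ optimization step pushes the trained weights to the collector's copy of the policy.
+
+ .. code-block:: python
+
+     import torch
+     from tensordict.nn import TensorDictModule
+     from torchrl.collectors import SyncDataCollector
+     from torchrl.envs.libs.gym import GymEnv
+
+     # Policy trained on the main process; the collector holds its own copy.
+     policy = TensorDictModule(
+         torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+     )
+     collector = SyncDataCollector(
+         lambda: GymEnv("Pendulum-v1"),
+         policy,
+         frames_per_batch=200,
+         total_frames=2_000,
+     )
+     for data in collector:
+         ...  # optimization step updating the policy parameters
+         # Push the latest weights to the collector's policy copy.
+         collector.update_policy_weights_()
+     collector.shutdown()
+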
- Practical Considerations
- ~~~~~~~~~~~~~~~~~~~~~~~~
-
- When designing a system that leverages this API, consider the following:
-
- - Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
- implementation accounts for potential delays and optimizes data transfer where possible.
- - Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
- the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
- suboptimal policy performance.
- - Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
- overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
- By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
- scenarios, ensuring that their policies remain up-to-date and performant.
-
.. currentmodule:: torchrl.collectors
.. autosummary::
@@ -182,7 +183,8 @@ scenarios, ensuring that their policies remain up-to-date and performant.
WeightUpdateReceiverBase
WeightUpdateSenderBase
- VanillaLocalWeightUpdater
+ VanillaWeightSender
+ VanillaWeightReceiver
MultiProcessedRemoteWeightUpdate
RayWeightUpdateSender
DistributedWeightUpdateSender