
Commit d56ebac

[Refactor] Refactor the weight update logic
ghstack-source-id: a53a09e4ff0c8ddd1cde46009481f8a8e43afbd7
Pull Request resolved: #2914
1 parent 6f68da8 commit d56ebac

File tree

8 files changed (+378 -152 lines)

docs/source/_static/img/param-update.svg

+186

docs/source/reference/collectors.rst

+31-29
@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
 Policy copy decision tree in Collectors.
 
 Weight Synchronization in Distributed Environments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------------------------
+
 In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
 latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
 mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
 
-Local and Remote Weight Updaters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sending and receiving model weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
 and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
 implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
 
-- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
-  the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
+- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
+  the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
+  instance to handle the "push" operation of the weights to the policy.
+  Users can extend this class to implement custom logic for synchronizing weights across a network of devices or processes.
+  For "regular", single node collectors, the :class:`~torchrl.collectors.VanillaWeightSender` will be used by default.
+- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for "pulling" the weights from a
+  distant parameter server. In many cases, it can be discarded as weights are forcibly pushed to the workers and policy
+  through `WeightUpdateSenderBase` instances, and a call to `WeightUpdateReceiverBase.update_weights` will be a no-op.
+  `WeightUpdateReceiverBase` is particularly useful when the training and inference occur on the same machine but on
   different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
   It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
   situations where the server decides when to update the worker policies).
-- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
-  remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
-  the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
-  devices or processes.
+
+Each of these classes has a private `_maybe_map_weights` method that can be overwritten, where a weight transform or
+formatting logic can be implemented.
+
+The following figure showcases how a somewhat complex weight update scheme can be implemented using these primitives.
+
+.. figure:: /_static/img/param-update.svg
+
+    In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
+    parameter server is handled by the main collector receiver. The main collector server sender instance sends the
+    parameters to the workers and triggers a remote call to `update_policy_weights_` in each or some workers.
+    Because this is a "push" operation, the receivers in the workers do not need to do anything. The senders
+    are responsible for writing the parameters to their copy of the policy.
 
 Extending the Updater Classes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -152,28 +169,12 @@ Default Implementations
 ~~~~~~~~~~~~~~~~~~~~~~~
 
 For common scenarios, the API provides default implementations of these updaters, such as
-:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
-:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
+:class:`~torchrl.collectors.VanillaWeightReceiver`, :class:`~torchrl.collectors.VanillaWeightSender`,
+:class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, :class:`~torchrl.collectors.RayWeightUpdateSender`,
+:class:`~torchrl.collectors.RPCWeightUpdateSender`, and :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
 These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
 distributed systems.
 
-Practical Considerations
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-When designing a system that leverages this API, consider the following:
-
-- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
-  implementation accounts for potential delays and optimizes data transfer where possible.
-- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
-  the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
-  suboptimal policy performance.
-- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
-  overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
-By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
-scenarios, ensuring that their policies remain up-to-date and performant.
-
 .. currentmodule:: torchrl.collectors
 
 .. autosummary::
@@ -182,7 +183,8 @@ scenarios, ensuring that their policies remain up-to-date and performant.
 
     WeightUpdateReceiverBase
     WeightUpdateSenderBase
-    VanillaLocalWeightUpdater
+    VanillaWeightSender
+    VanillaWeightReceiver
     MultiProcessedRemoteWeightUpdate
     RayWeightUpdateSender
     DistributedWeightUpdateSender
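
As a reading aid for the section added above: senders expose the private `_maybe_map_weights` hook plus `all_worker_ids` and (after this commit) `push_weights`, as the distributed updaters further down show. A minimal, hypothetical sender sketch built on those names could look as follows; the constructor, the half-precision cast, and the in-place `update_` call are illustrative assumptions, not part of this commit.

# Hypothetical sketch of a custom sender, based on the method names visible in
# this diff (_maybe_map_weights, all_worker_ids, push_weights). The constructor,
# the fp16 cast and the in-place `update_` are illustrative assumptions.
import torch
from tensordict import TensorDictBase

from torchrl.collectors import WeightUpdateSenderBase


class HalfPrecisionWeightSender(WeightUpdateSenderBase):
    def __init__(self, policy_weights: TensorDictBase):
        super().__init__()
        self.policy_weights = policy_weights  # this worker's copy of the policy params

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        # Weight transform/formatting hook: downcast before the "push".
        return server_weights.to(torch.float16)

    def all_worker_ids(self) -> list[int]:
        # A single local worker in this toy setting.
        return [0]

    def push_weights(self, weights=None, worker_ids=None):
        # "Push" operation: write the (transformed) weights into the local policy copy.
        weights = self._maybe_map_weights(weights)
        self.policy_weights.update_(weights)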

torchrl/collectors/__init__.py

+4-2
@@ -15,7 +15,8 @@
 from .weight_update import (
     MultiProcessedWeightUpdate,
     RayWeightUpdater,
-    VanillaWeightUpdater,
+    VanillaWeightReceiver,
+    VanillaWeightSender,
     WeightUpdateReceiverBase,
     WeightUpdateSenderBase,
 )
@@ -24,7 +25,8 @@
     "RandomPolicy",
     "WeightUpdateReceiverBase",
     "WeightUpdateSenderBase",
-    "VanillaWeightUpdater",
+    "VanillaWeightSender",
+    "VanillaWeightReceiver",
     "RayWeightUpdater",
     "MultiProcessedWeightUpdate",
     "aSyncDataCollector",

torchrl/collectors/collectors.py

+20-12
@@ -52,7 +52,8 @@
 from torchrl.collectors.utils import split_trajectories
 from torchrl.collectors.weight_update import (
     MultiProcessedWeightUpdate,
-    VanillaWeightUpdater,
+    VanillaWeightReceiver,
+    VanillaWeightSender,
     WeightUpdateReceiverBase,
     WeightUpdateSenderBase,
 )
@@ -299,9 +300,8 @@ def update_policy_weights_(
                 for the update. If not provided, the method will attempt to fetch the weights using the configured
                 weight updater.
             worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
-                workers that need to be updated. This is relevant when using a remote weights updater, which must
-                be specified during the data collector's initialization. If `worker_ids` is provided without a
-                configured remote weights updater, a TypeError will be raised.
+                workers that need to be updated. This is relevant when the collector has more than one worker associated
+                with it.
 
         Raises:
             TypeError: If `worker_ids` is provided but no `weight_update_sender` is configured.
@@ -318,11 +318,8 @@ def update_policy_weights_(
 
         """
         if self.weight_update_receiver is not None:
-            self.weight_update_receiver(policy_weights, **kwargs)
-        if self.weight_update_sender is not None:
-            self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
-        elif worker_ids is not None:
-            raise TypeError("worker_ids was passed but weight_update_sender was None.")
+            policy_weights = self.weight_update_receiver(policy_weights, **kwargs)
+        self.weight_update_sender(policy_weights, worker_ids=worker_ids, **kwargs)
 
     def __iter__(self) -> Iterator[TensorDictBase]:
         try:
@@ -539,7 +536,7 @@ class SyncDataCollector(DataCollectorBase):
            Defaults to ``False``.
        weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
            or its subclass, responsible for updating the policy weights on the local inference worker.
-           If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
+           If not provided, a :class:`~torchrl.collectors.VanillaWeightSender` will be used by default,
            which directly fetches and applies the weights from the server.
            Consider using a constructor if the updater needs to be serialized.
        weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
@@ -893,10 +890,21 @@ def __init__(
         self._frames = 0
         self._iter = -1
 
-        if weight_update_receiver is None:
-            weight_update_receiver = VanillaWeightUpdater(
+        if weight_update_sender is None:
+            weight_update_sender = VanillaWeightSender(
                 weight_getter=self.get_weights_fn, policy_weights=self.policy_weights
             )
+        elif not isinstance(weight_update_sender, WeightUpdateSenderBase):
+            raise TypeError(
+                "weight_update_sender must be a subclass of WeightUpdateSenderBase"
+            )
+
+        if weight_update_receiver is None:
+            weight_update_receiver = VanillaWeightReceiver()
+        elif not isinstance(weight_update_receiver, WeightUpdateReceiverBase):
+            raise TypeError(
+                "weight_update_receiver must be a subclass of WeightUpdateReceiverBase"
+            )
 
         self.weight_update_receiver = weight_update_receiver
         self.weight_update_sender = weight_update_sender
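
Reading the hunks above together: the collector now always builds a sender (and a receiver), and `update_policy_weights_` lets the receiver pull first, then hands the result to the sender. A rough usage sketch under those assumptions, with `make_env` and `policy` as user-supplied placeholders:

# Rough usage sketch of the refactored flow. `make_env` and `policy` are
# placeholders for a user-provided environment constructor and trainable policy;
# the defaults (VanillaWeightSender / VanillaWeightReceiver) come from this diff.
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    create_env_fn=make_env,        # placeholder env constructor
    policy=policy,                 # placeholder policy module
    frames_per_batch=64,
    total_frames=1024,
    weight_update_sender=None,     # None -> VanillaWeightSender built in __init__
    weight_update_receiver=None,   # None -> VanillaWeightReceiver built in __init__
)

for data in collector:
    ...  # optimize the trainer's copy of the policy here
    # The receiver "pulls" (a no-op for the vanilla receiver), then the sender
    # "pushes" the weights into the collector's policy copy / workers.
    collector.update_policy_weights_()

collector.shutdown()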

torchrl/collectors/distributed/generic.py

+1-1
@@ -1053,7 +1053,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
     def all_worker_ids(self) -> list[int] | list[torch.device]:
         raise NotImplementedError
 
-    def update_weights(
+    def push_weights(
         self,
         weights: TensorDictBase | None = None,
         worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,

torchrl/collectors/distributed/rpc.py

+1-1
@@ -887,7 +887,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
     def all_worker_ids(self) -> list[int] | list[torch.device]:
         raise NotImplementedError
 
-    def update_weights(
+    def push_weights(
        self,
        weights: TensorDictBase | None = None,
        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
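
The two hunks above rename the hook that distributed senders implement from `update_weights` to `push_weights`. A hypothetical subclass using the new name and the signature shown in the diff might look like this; the transport helper is a placeholder.

# Hypothetical distributed sender overriding the renamed `push_weights` hook.
# The signature mirrors the one in this diff; `_send_to_worker` is a placeholder
# for whatever transport (RPC, collectives, Ray, ...) a concrete sender would use.
from __future__ import annotations

import torch
from tensordict import TensorDictBase

from torchrl.collectors import WeightUpdateSenderBase


class MyDistributedWeightSender(WeightUpdateSenderBase):
    def all_worker_ids(self) -> list[int]:
        return [0, 1]  # toy example: two remote workers

    def push_weights(
        self,
        weights: TensorDictBase | None = None,
        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
    ) -> None:
        if worker_ids is None:
            worker_ids = self.all_worker_ids()
        elif not isinstance(worker_ids, (list, tuple)):
            worker_ids = [worker_ids]
        for worker in worker_ids:
            self._send_to_worker(worker, weights)  # placeholder transport call

    def _send_to_worker(self, worker, weights) -> None:
        raise NotImplementedError("plug in the actual RPC/collective/Ray transport here")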

torchrl/collectors/llm.py

+1-1
@@ -70,7 +70,7 @@ class LLMCollector(SyncDataCollector):
            Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
        weight_update_receiver (WeightUpdateReceiverBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateReceiverBase`
            or its subclass, responsible for updating the policy weights on the local inference worker.
-           If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
+           If not provided, a :class:`~torchrl.collectors.VanillaWeightSender` will be used by default,
            which directly fetches and applies the weights from the server.
            Consider using a constructor if the updater needs to be serialized.
        weight_update_sender (WeightUpdateSenderBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdateSenderBase`
