[Refactor] Refactor the weight update logic #2914
base: gh/vmoens/130/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2914
Note: Links to docs will display an error until the docs builds have been completed.
❌ 13 New Failures, 1 Cancelled Job, 1 Unrelated Failure as of commit b9e7568 with merge base 0475cbf
NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: f685b500e05e61c421297bd9f0215167a4e5642f Pull Request resolved: #2914
ghstack-source-id: fe044d88e919be026afb2e1f8756ff986e9a65b0 Pull Request resolved: #2914
I'm trying to rethink sender and receiver one last time.
I think we always need a sender: in some way, you always need to push the weights somewhere (because vllm will never ask for weights, you push the weights to vllm).
In centralized settings, where you have a central collector orchestrating satellite ones, the responsibility of the central collector is to push weights to the workers (note that this is not the scheme we are using, which is decentralized).
The receiver, on the other hand, is accessory: it covers the kind of setting where a worker can ask for weights by itself at a given interval or when some conditions are met.
The update_policy_weights_
function then looks like
def update_policy_weights_(self, *args, **kwargs):
weights = self.receive(*args, **kwargs) # this is a no-op if the weights (handle) are in the args
self.send(weights) # this should never be a no-op, as this is where the weight update actually occurs
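A minimal, self-contained sketch of this receive-then-send flow. The class and its dict-based "weights" are hypothetical stand-ins to show the control flow, not the torchrl implementation:

```python
# Hypothetical sketch of the receive-then-send flow described above.
# `SimpleWeightUpdater` and its dict-based weights are illustrative only.
class SimpleWeightUpdater:
    def __init__(self, source_weights):
        self._source = source_weights  # e.g. trainer-side parameters
        self._policy = {}              # e.g. inference-side parameters

    def receive(self, weights=None):
        # No-op when the weights (or a handle to them) are passed in directly.
        return weights if weights is not None else dict(self._source)

    def send(self, weights):
        # The actual update: push the weights into the policy. Never a no-op.
        self._policy.update(weights)

    def update_policy_weights_(self, weights=None):
        weights = self.receive(weights)
        self.send(weights)
        return self._policy
```

With this shape, passing explicit weights skips the fetch, while calling with no arguments pulls from the source before pushing.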
@@ -63,72 +66,80 @@ def register_collector(self, collector: DataCollectorBase):  # noqa

    @property
    def collector(self) -> torchrl.collectors.DataCollectorBase:  # noqa
        return self._collector_wr() if self._collector_wr is not None else None
        """The collector or container of the receiver.
I'm saying collector or container because we may want to use these classes with something other than a collector (e.g. have a sender in a parameter server).
cc @mikaylagawarecki
ghstack-source-id: a53a09e4ff0c8ddd1cde46009481f8a8e43afbd7 Pull Request resolved: #2914
.. figure:: /_static/img/param-update.svg

In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
why do you envision this to hold various copies rather than one?
.. figure:: /_static/img/param-update.svg

In this setting, a parameter server holds various copies of the parameters. The "pulling" of the weights from the
parameter server is handled by the main collector receiver. The main collector server sender instance sends the
"main collector server"
Is it accurate to think of this as the main thread in RayCollector?
the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
  the policy or to remote inference workers. Every collector -- server or worker -- should have a `WeightUpdateSenderBase`
  instance to handle the "push" operation of the weights to the policy.
I think "push/pull" and "sender/receiver" are confusing 🫤 In particular, for me the Receiver == "Puller" part is tough to wrap my head around.
Pull architecture: the client sends the request, and the server responds accordingly
Push architecture: the server pushes data to clients as updates become available
The confusion for me is that I think of sender --> receiver as "sender actively pushes, receiver passively receives". Hence receiver == puller is not intuitive
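A toy contrast of the two styles under discussion may make the distinction concrete. The names and dict-based "weights" here are made up for illustration, not torchrl API:

```python
# Toy contrast of push vs. pull weight updates (illustrative names only).
server_weights = {"version": 1}

class PushStyleSender:
    """Push architecture: the server sends updates as they become available."""
    def __init__(self):
        self.worker_weights = {}

    def on_new_weights(self):
        # Server-driven: called by the trainer whenever weights change.
        self.worker_weights.update(server_weights)

class PullStyleReceiver:
    """Pull architecture: the worker requests weights on its own schedule."""
    def __init__(self):
        self.worker_weights = {}

    def maybe_pull(self, step, interval=10):
        # Worker-driven: fetch only when the worker's own condition is met.
        if step % interval == 0:
            self.worker_weights.update(server_weights)
```

The naming tension in the thread is visible here: the "receiver" is the active party in the pull style, which is exactly the counter-intuitive part.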
Got it
In this context I'm starting to think that having 2 separate classes will always be confusing so perhaps we should just have one that can be customized at will.
In every case I've dealt with so far, it has never occurred that I could write senders and receivers that compose freely, so that tells me that making a perfectly composable API may be an illusion.
I'm myself a bit confused about what should live within each of these classes to be honest...
I'll refactor this to have a single Updater class that gives a somewhat unopinionated implementation of the update functionality!
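One possible shape for such a single, unopinionated updater class, sketched here as an assumption (the name `WeightUpdater` and the callable-based customization are hypothetical, not the torchrl API): transport and application become user-supplied callables, so push- and pull-style behaviour is configuration rather than a separate class.

```python
# Hypothetical sketch of a single, unopinionated updater class.
class WeightUpdater:
    def __init__(self, get_weights, apply_weights):
        self._get_weights = get_weights      # where weights come from (pull)
        self._apply_weights = apply_weights  # how they reach the policy (push)

    def update_policy_weights_(self, weights=None):
        # Fetch only if no weights (or handle) were passed in directly.
        if weights is None:
            weights = self._get_weights()
        # Applying the weights is the one step that is never a no-op.
        self._apply_weights(weights)
```

A caller could then wire `get_weights` to a parameter server, a shared buffer, or a no-op, without subclassing two cooperating classes.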
@@ -118,27 +118,44 @@ try to limit the cases where a deepcopy will be executed. The following chart sh

Policy copy decision tree in Collectors.

Weight Synchronization in Distributed Environments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------------------------
I read the diagram above as:
- CollectorServer: main thread of RayCollector
- Collector Worker {i}: remote DataCollector
If this read is correct, in my mind, it might sometimes make sense to have the receiver on the collector worker rather than the collector server
e.g. if the number of remote workers is sufficiently high, a collector worker might not be colocated with the collector server; in that case it might not make sense to pass the weights "two hops" to get to the worker
Separate qn -- from the diagram it looks like the collector server chooses when to pull from the param server and then "forcefully pushes" to all the workers at once. Is this design intentional? (e.g. Is the purpose of this to batch up workers to different collector servers and update them in batches?)
Stack from ghstack (oldest at bottom):