Add support for AC budget API #1731
base: main
Conversation
I read the blog and find the memory budget idea cool. Have you tried out the implementation on some model (e.g., llama3) with torch.compile? I'm curious whether it works end to end and whether the performance is better.
Thanks for sharing! We would love to see more verification - e.g., correctness and loss curves, and performance analysis on titan-supported models (llama3, etc.). cc @soulitzer for reviewing
I haven't done runs on llama3, but in our benchmarks it showed significant improvements over regular SAC. This is why I wanted to upstream this :)
I agree with @wwwjn that
We would love to see more verification - e.g., correctness and loss curves, and performance analysis on titan-supported models (llama3, etc.)
Please refer to https://github.com/pytorch/torchtitan/blob/main/CONTRIBUTING.md#proof-of-value
Should this support selection of the options in https://github.com/pytorch/pytorch/blob/main/torch/_functorch/config.py#L147-L169?
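For reference, a minimal sketch of the knobs that linked block covers, assuming they are exposed under torch._functorch.config (exact names and defaults may differ across PyTorch versions):

import torch._functorch.config as functorch_config

# Fraction of the default activation memory to target: 0.0 = recompute everything, 1.0 = no extra recompute.
functorch_config.activation_memory_budget = 0.5
# How per-op runtime is estimated when picking the cheapest ops to recompute: "flops" or "profile".
functorch_config.activation_memory_budget_runtime_estimator = "flops"
# Solver for the 0-1 knapsack over candidate ops: "dp", "greedy", or "ilp" (scipy required for "ilp").
functorch_config.activation_memory_budget_solver = "dp"
# Optionally dump an SVG of the runtime-vs-memory Pareto frontier.
functorch_config.visualize_memory_budget_pareto = False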
I addressed your comments, rebased to avoid conflicts, and added the other parts of the API, but I am totally okay with reverting it back; just waiting for your opinion.
Thank you! Sounds good in general. Left some comments. Please see if they make sense.
)
model.layers.register_module(layer_id, transformer_block)
if ac_config.mode == "memory_budget":
    assert (model_compile_enabled is True), "Memory budget mode requires model to be compiled"
Suggested change:
-    assert (model_compile_enabled is True), "Memory budget mode requires model to be compiled"
+    assert model_compile_enabled, "Memory budget mode requires model to be compiled"
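For context on why compile is required: the budgeted recomputation is applied by the partitioner when the region is traced by torch.compile, so in eager mode the setting has no effect. A standalone sketch (illustrative only, not this PR's wiring):

import torch
import torch._functorch.config as functorch_config

# Assumed knob from torch._functorch.config; only takes effect inside compiled regions.
functorch_config.activation_memory_budget = 0.5

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
)
compiled = torch.compile(model)

x = torch.randn(8, 1024, requires_grad=True)
# During the backward trace, the partitioner chooses what to recompute under the budget.
compiled(x).sum().backward()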
activation_memory_budget_runtime_estimator: Literal["flops", "profile"] = "flops"
"""
This controls how we estimate the runtime when deciding what the cheapest
operators to recompute are. The options are
"flops": Bases it off of the flop count provided by torch.utils.flop_counter
"profile": Benchmarks each operator to come up with a runtime
"""
activation_memory_budget_solver: Literal["dp", "greedy", "ilp"] = "dp"
"""
This controls the solver used for the 0-1 knapsack. By default we use a
quantized DP solution ("dp"). The other approaches are "greedy" and "ilp"
(which has a scipy dependency).
"""
It's hard to tell if users should change them. Maybe let's remove them for now and see if people have complaints.
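Whichever fields are ultimately kept, applying them would roughly amount to forwarding the job-config values into the partitioner config before compilation. A hypothetical helper (not this PR's actual code; field names mirror the diff above):

import torch._functorch.config as functorch_config

def apply_memory_budget_config(ac_config) -> None:
    # Hypothetical mapping from the torchtitan AC config onto the assumed
    # torch._functorch.config knobs; names may differ in the final PR.
    functorch_config.activation_memory_budget = ac_config.activation_memory_budget
    functorch_config.activation_memory_budget_solver = ac_config.activation_memory_budget_solver
    functorch_config.activation_memory_budget_runtime_estimator = (
        ac_config.activation_memory_budget_runtime_estimator
    )
    functorch_config.visualize_memory_budget_pareto = ac_config.visualize_memory_budget_pareto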
quantized DP solution ("dp"). The other approaches are "greedy" and "ilp"
(which has a scipy dependency).
"""
visualize_memory_budget_pareto: bool = False
This one I'm not sure about. It could be useful, because otherwise people have no idea what value they should set the budget to.
Btw, the picture you linked doesn't seem to be in "increments of 0.5" - it looks like 0.05.
Whether to stop recomputing early when all activations have already been
rematerialized.
"""
activation_memory_budget: float = 1.0
Maybe set 0.5 by default; otherwise, if a user turns on memory_budget
without tuning this, nothing would happen.
This dumps out an SVG visualization of the expected runtime vs. activation
memory tradeoffs for all memory budget values from 0 to 1 in increments of
0.5. See an example here:
https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
Please describe what folder it'll dump into.
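For anyone trying this out, enabling the visualization is just a config flip; a sketch assuming the torch._functorch.config knobs (where the SVG is written is version-dependent, so check torch/_functorch/config.py in your installed PyTorch rather than relying on this):

import torch._functorch.config as functorch_config

# Dump the runtime-vs-memory Pareto SVG while tuning the budget.
functorch_config.visualize_memory_budget_pareto = True
functorch_config.activation_memory_budget = 0.5
# Note: the output folder is not set here; it depends on the PyTorch version.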
Whether to stop recomputing early when all activations have already been
rematerialized.
"""
activation_memory_budget: float = 1.0
nit: please add a blank line between configs; otherwise it's hard to tell whether a message is associated with the config above or below.
Whether to stop recomputing early when all activations have already been
rematerialized.
"""
activation_memory_budget: float = 1.0
The activation_ prefix is redundant, so we can just call it memory_budget:
Suggested change:
-activation_memory_budget: float = 1.0
+memory_budget: float = 1.0
0.0 corresponds to the activation memory from applying
activation checkpointing to the full compiled region, and 1.0 corresponds to
the activation memory from the default runtime-optimized strategy.
"""
Please add a link to the post https://pytorch.org/blog/activation-checkpointing-techniques/
Inspired by the blogpost:
https://pytorch.org/blog/activation-checkpointing-techniques/