-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable…
… config selectable Norm Type (#181) This PR has multiple aspects: 1 - Adds a new Triton based Fused RMSNorm I wrote. I've verified it's numerical accuracy on both forward and backward with a unit test. It improves MFU by +15% with FSDP2 7B, and compiled slightly by +1.2%: <img width="545" alt="Screenshot 2024-03-29 at 5 18 14 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/8f16fae9-947b-4720-a370-b954779c33a7"> 2 - Adds norms.py to house all 4 norm types, and standardizes to [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]. Norms.py has a create_norms function that then creates the appropriate norm. 3 - Adds np_layernorm, which is layernorm with no affine transformation. 4 - Updates model.py to now support plug and play of any supported norm. Thus instead of this type of if/then logic in the model class: <img width="928" alt="Screenshot 2024-03-30 at 1 52 07 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/ba7cb976-580f-4471-a79b-a584f7d20693"> We simply have this: <img width="1129" alt="Screenshot 2024-03-30 at 1 52 23 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/aba48b4d-1620-4059-840d-e620468f00f2"> This then allows for easy plug and play of any norm type with no fiddling around in the model code. 5 - updates run_llama_train.sh to randomly select a port vs previous fixed port number. (thanks @yifuwang for this tip!) 6 - Now users can quickly select the norm of their choice via the config file: <img width="774" alt="Screenshot 2024-03-30 at 3 01 43 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/3238b375-dc21-4ee2-a5fa-f6571da79edb"> 7 - adds a NotImpl error if users try to run TP + fused_rnsmorm to avoid any confusion (per @tianyu-l feedback): ~~~ NotImplementedError: fused_rmsnorm not yet compatible with TP. Please use rmsnorm. ~~~
- Loading branch information
Showing
11 changed files
with
346 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.