Commit
Update on "enable TritonFusedRMSNorm with local_map annotation"
**Summary**

This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step: 1 loss: 12.2971 memory: 23.68GiB(29.92%) wps: 258 mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step: 5 loss: 11.6237 memory: 30.98GiB(39.14%) wps: 453 mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10 loss: 10.7210 memory: 30.98GiB(39.14%) wps: 580 mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15 loss: 9.4563 memory: 30.98GiB(39.14%) wps: 585 mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20 loss: 8.9246 memory: 30.98GiB(39.14%) wps: 582 mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step: 1 loss: 12.2779 memory: 23.64GiB(29.86%) wps: 186 mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step: 5 loss: 11.6073 memory: 31.11GiB(39.31%) wps: 447 mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10 loss: 10.6355 memory: 31.11GiB(39.31%) wps: 606 mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15 loss: 9.5591 memory: 31.11GiB(39.31%) wps: 596 mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20 loss: 9.0287 memory: 31.11GiB(39.31%) wps: 605 mfu: 11.23%
```

[ghstack-poisoned]
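
To compare the two runs in the test plan above numerically, the `wps` column can be pulled out of the log lines and averaged. The following is a minimal sketch, not part of the PR: the helper names `parse_wps` and `mean_wps` are hypothetical, and treating steps 10 and later as steady state is an assumption made here to skip warm-up. The excerpt covers only 20 steps on one rank, so the number it yields is not a substitute for the figure quoted in the summary.

```python
import re

# Log lines copied verbatim from the test plan (rank 2 only).
RMSNORM_LOG = """\
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step: 1 loss: 12.2971 memory: 23.68GiB(29.92%) wps: 258 mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step: 5 loss: 11.6237 memory: 30.98GiB(39.14%) wps: 453 mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10 loss: 10.7210 memory: 30.98GiB(39.14%) wps: 580 mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15 loss: 9.4563 memory: 30.98GiB(39.14%) wps: 585 mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20 loss: 8.9246 memory: 30.98GiB(39.14%) wps: 582 mfu: 10.80%
"""

FUSED_LOG = """\
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step: 1 loss: 12.2779 memory: 23.64GiB(29.86%) wps: 186 mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step: 5 loss: 11.6073 memory: 31.11GiB(39.31%) wps: 447 mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10 loss: 10.6355 memory: 31.11GiB(39.31%) wps: 606 mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15 loss: 9.5591 memory: 31.11GiB(39.31%) wps: 596 mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20 loss: 9.0287 memory: 31.11GiB(39.31%) wps: 605 mfu: 11.23%
"""

# Matches the "step: N ... wps: M" portion of one log line.
LINE_RE = re.compile(r"step:\s*(\d+)\s.*?wps:\s*([\d,]+)")


def parse_wps(log: str) -> dict[int, int]:
    """Return a mapping of training step to words-per-second."""
    return {
        int(m.group(1)): int(m.group(2).replace(",", ""))
        for m in LINE_RE.finditer(log)
    }


def mean_wps(log: str, min_step: int = 10) -> float:
    """Average wps over steps >= min_step (assumed steady state)."""
    values = [wps for step, wps in parse_wps(log).items() if step >= min_step]
    return sum(values) / len(values)


baseline = mean_wps(RMSNORM_LOG)
fused = mean_wps(FUSED_LOG)
gain_percent = (fused / baseline - 1.0) * 100.0
print(f"rmsnorm: {baseline:.1f} wps, fused_rmsnorm: {fused:.1f} wps, "
      f"relative change: {gain_percent:+.2f}%")
```

The same helpers can be pointed at a longer log, or at the `mfu` column with a small change to the regular expression, to get a more representative comparison.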
Showing 8 changed files with 45 additions and 29 deletions.