InsertReshardingsPass decomposes matmul/linear+ReduceScatter. #4239
```cpp
// rDIDx{i0}, usually a product of an Allreduce or a ReduceScatter, is
// treated as replicated. This way `iDIDx{i0} => rDIDx{i0}` is considered
// resharding.
if (id->isReduction()) {
```
This is moved from line 529 so that `getShardedLogicalAxis` also enjoys the fix.
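
To illustrate the rule the quoted snippet implements, here is a standalone sketch with hypothetical types (`ParallelType`, `IterDomain`, `isShardedAxis`, and `needsResharding` below are illustrative stand-ins, not nvFuser's real classes): an axis parallelized on `DIDx` but marked as a reduction (`rDIDx`) holds partial sums, so for sharding comparisons it behaves like a replicated axis, which makes an `iDIDx{i0} => rDIDx{i0}` producer/consumer pair count as resharding.

```cpp
#include <iostream>

// Hypothetical stand-ins, just to model the rule.
enum class ParallelType { Serial, DIDx };

struct IterDomain {
  ParallelType ptype = ParallelType::Serial;
  bool reduction = false;  // true models rDIDx{...}, false models iDIDx{...}
  bool isReduction() const { return reduction; }
  bool isDeviceDim() const { return ptype == ParallelType::DIDx; }
};

// An axis counts toward a tensor's sharding only if it is a device dimension
// and NOT a reduction: rDIDx carries partial sums and is treated as
// replicated, mirroring the early-out in the quoted snippet.
bool isShardedAxis(const IterDomain& id) {
  if (id.isReduction()) {
    return false;
  }
  return id.isDeviceDim();
}

// If exactly one side of a producer/consumer pair is sharded, the edge is a
// resharding, so a communication op must be inserted there.
bool needsResharding(const IterDomain& producer, const IterDomain& consumer) {
  return isShardedAxis(producer) != isShardedAxis(consumer);
}

int main() {
  IterDomain iDIDx{ParallelType::DIDx, /*reduction=*/false};
  IterDomain rDIDx{ParallelType::DIDx, /*reduction=*/true};
  // `iDIDx{i0} => rDIDx{i0}` is considered resharding:
  std::cout << std::boolalpha << needsResharding(iDIDx, rDIDx) << '\n';  // true
}
```

Treating rDIDx as replicated is what lets the pass detect that a communication (an Allreduce or a ReduceScatter) has to be inserted at that edge.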
!test
The failing test is due to a bug in FusionCache, which I'll sort out in a separate PR. `test_row_parallel_linear` and `test_linear_reduce_scatter` fail with the same error. I believe this is the same error as https://github.com/NVIDIA/Fuser/blob/main/tests/python/multidevice/test_overlap.py#L91. cc @rdspring1
!test |
!test |
The former is per module and the latter is per function.
!test |
As a follow-up to #4209, handle ReduceScatter as well.
Fixes #4133
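
For intuition about what the decomposed pattern computes, here is a self-contained numeric sketch in plain C++ (no nvFuser APIs; the sizes and names are made up): each device owns a K-shard of the matmul inputs, the local matmul produces a partial result (the rDIDx tensor from the snippet above), and the ReduceScatter both sums the partials across devices and leaves each device with only a row-slice of the output.

```cpp
#include <array>
#include <iostream>
#include <vector>

constexpr int D = 2;                // number of devices
constexpr int M = 4, K = 4, N = 2;  // matmul sizes; M and K divisible by D

int main() {
  // Global inputs: a is M x K, b is K x N, filled with simple values.
  std::array<std::array<double, K>, M> a{};
  std::array<std::array<double, N>, K> b{};
  for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k)
      a[i][k] = i + k;
  for (int k = 0; k < K; ++k)
    for (int j = 0; j < N; ++j)
      b[k][j] = k - j;

  // Step 1, which stays a local compute op after the pass: device d owns
  // columns [d*K/D, (d+1)*K/D) of `a` (and the matching rows of `b`) and
  // computes a partial M x N product: the partial-sum (rDIDx) tensor.
  std::vector<std::array<std::array<double, N>, M>> partial(D);
  for (int d = 0; d < D; ++d)
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        for (int k = d * (K / D); k < (d + 1) * (K / D); ++k)
          partial[d][i][j] += a[i][k] * b[k][j];

  // Step 2, the inserted resharding, lowered to a ReduceScatter: sum the
  // partials across devices and scatter row-blocks, so device d keeps only
  // rows [d*M/D, (d+1)*M/D) of the final M x N result.
  for (int d = 0; d < D; ++d) {
    std::cout << "device " << d << " output rows:\n";
    for (int i = d * (M / D); i < (d + 1) * (M / D); ++i) {
      for (int j = 0; j < N; ++j) {
        double sum = 0;
        for (int p = 0; p < D; ++p)
          sum += partial[p][i][j];
        std::cout << sum << (j + 1 < N ? " " : "\n");
      }
    }
  }
  return 0;
}
```

As with the Allreduce case handled in #4209, the idea is that the local matmul/linear remains an ordinary compute expression, while the `iDIDx => rDIDx` transition is split out into its own resharding expression that can be lowered to a single ReduceScatter.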