[WIP] IdModel: Fix invalid promotion selection #3877
base: main
```diff
@@ -654,7 +654,7 @@ TEST_F(Tutorial, IdModelReshapeAnalysis) {
   fusion.addOutput(tv3);

   IdModel id_model(&fusion);
-  ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
+  ValGraph& exact_graph = id_model.buildExactGraph();
```
Unrelated fix
Fixes #3702
`computeCoveredGroups` is used in the loop promotion analysis to find a promotion ID that depends on all input exact groups. Specifically, it traverses the exact graph in topological order and assigns to each exact group the set of input exact groups it depends on. This information is used to find promotion IDs, since the promotion ID of a loop group must depend on every input group that any ID of that loop group depends on.

Here's a simple example. Suppose there are two 2D input tensors, `t0` and `t1`, and they are added together as follows:

The resulting tensors and IDs look like below:
Specifically, `i6`, `i7`, and `i8` form a loop group with `i8` (or `i7`) as its promotion ID. The dependent input exact groups of `i8` are `{i0, i2, i4}` and `{i3, i5}`. Note that broadcast IDs are ignored. This set covers all of the dependent groups of every member ID of the loop group. That's obviously true for `i7`, as it's exact mapped with `i8`. It's also true for `i6`, which only has `{i0, i2, i4}` as its dependent group (the dependency on `b1` is ignored since it's a broadcast ID).

The
`computeCoveredGroups` function works fine in common transformation patterns, where we tend to do all merges first and then splits. However, it has a problem when a split is followed by a merge. Consider a fusion with two input tensors, `t0` and `t1`, as shown below:

Similar to the previous case, the final merged IDs form a loop group of
`{i8, i9, i10}`. In this case, it should be promoted to `i9` or `i10` (they are exact mapped, so it doesn't matter which one is picked). However, `computeCoveredGroups` assigns the exact group of `{i2, i3}` not just to the exact group of `{i9, i10}` but also to the exact groups of `{i0, i4, i6}`, `{i5, i7}`, and `{i8}`, since all of them depend on `{i2, i3}`. This is problematic because `findPromotionOfLoopGroup` [may then incorrectly choose](https://github.com/NVIDIA/Fuser/blob/main/csrc/id_model/loop_promotion.cpp#L926-L934), e.g., `i8` as the promotion of the loop group.

The root cause of the issue lies in
[`computeCoveredGroups`](https://github.com/NVIDIA/Fuser/blob/main/csrc/id_model/loop_promotion.cpp#L779-L822). It simply propagates input exact groups down to consumer exact groups through exact expr groups, regardless of which actual expr a group represents. In particular, when the expr is a split, the input groups covered by the split input ID are propagated unchanged to the split output groups. However, the output groups do not actually cover the entire input groups; each should be considered to cover only its portion of the split. In the above example, `{i0, i4, i6}` and `{i5, i7}` should cover only the outer and inner split of `{i2, i3}`, respectively. Therefore, `{i8}` only inherits the outer-split coverage from `{i0, i4, i6}`, whereas `{i9, i10}` gets both the outer and inner split groups, which allows `findPromotionOfLoopGroup` to correctly identify `{i9, i10}` as the promotion of the loop group.

Previously, the coverage information was just a map from each exact
`ValGroup` to a `ValGroups` of dependent input groups. Since that is not enough to represent the coverage information as precisely as discussed above, this PR introduces an additional data structure, `CoveredGroup`, which represents a covered exact group of an exact group. Since each exact group may cover multiple groups, `std::unordered_set<CoveredGroup>` is used to represent all covered groups of an exact group.

`CoveredGroup` optionally encodes the split of a covered group, as in the `{i0, i4, i6}` case above. The `CoveredGroup` of `{i0, i4, i6}` would have `{i0, i4, i6}` as `group_` and the covered groups of its split input group, `{i2, i3}`, as `split_in_`, with `is_inner_` being false. Similarly, `{i5, i7}` would have `{i5, i7}` as its `group_` and the covered groups of `{i2, i3}` as its `split_in_`, but with `is_inner_` being true. Therefore, `{i8}` would just get the same coverage info propagated from `{i0, i4, i6}`, whereas `{i9, i10}` would get that as well as the one from `{i5, i7}`, effectively gathering both the outer and inner output groups of the split. This way, we avoid losing information about covered input groups when propagating it through split exprs.