
[AMDAIEFuseFillIntoForall] Handle case where fill output is not sliced #976

Merged: 6 commits merged into nod-ai:main on Dec 11, 2024

Conversation

@newling (Contributor) commented Dec 9, 2024

For 2x2 or 4x4 tiling, the chain of ops after the fill looks like:

 %9 = linalg.fill ins(%cst : f32) ...
 ... 
 %12 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %9)  ...
 ...
 %extracted_slice_19 = tensor.extract_slice %arg5 ... 
 ...
 %19 = linalg.generic  ... outs(%extracted_slice_19 ... )

i.e., the filled value enters an extract_slice inside the scf.forall. But for 1x1 tiling, it looks like:

 %9 = linalg.fill ins(%cst : f32)
 ...
 %14 = scf.forall (%arg3, %arg4) in (1, 1) shared_outs(%arg5 = %9)
 ...
 %19 = linalg.generic  ... outs(%arg5... )

i.e. there is no intermediate extract_slice.

Before this PR, the logic was hardcoded to look for an extract_slice; this PR relaxes that.

Before this PR, 1x1 tiling hits

funcOp->emitOpError("There is no extract tensor slice.");

After this PR, compilation progresses further (it fails much later, in the objectfifo pipeline, for reasons unrelated to this change).

@yzhang93 (Contributor) commented Dec 9, 2024

I think in applications when we see loops like

%14 = scf.forall (%arg3, %arg4) in (1, 1) shared_outs(%arg5 = %9)

the forall should be canonicalized away, and we shouldn't have to fuse the fill op into the loop.

However, I can see why this is needed here. We want to keep the thread mapping attribute?

@newling (Contributor, Author) commented Dec 9, 2024

> However, I can see why this is needed here. We want to keep the thread mapping attribute?

AFAIK the canonicalizer never removes scf.forall ops with thread/block ids. Which is, as I think you're saying, exactly what we want.

newling marked this pull request as ready for review on December 9, 2024 at 23:07
@yzhang93 (Contributor) commented Dec 9, 2024

> However, I can see why this is needed here. We want to keep the thread mapping attribute?

> AFAIK the canonicalizer never removes scf.forall ops with thread/block ids. Which is, as I think you're saying, exactly what we want.

We can always modify the canonicalization patterns to include such a case. But the real question is: do we want such a loop to be canonicalized away? In our current pipeline, a lot of passes depend on the thread mapping attribute.

@newling (Contributor, Author) commented Dec 9, 2024

> However, I can see why this is needed here. We want to keep the thread mapping attribute?

> AFAIK the canonicalizer never removes scf.forall ops with thread/block ids. Which is, as I think you're saying, exactly what we want.

> We can always modify the canonicalization patterns to include such a case. But the real question is: do we want such a loop to be canonicalized away? In our current pipeline, a lot of passes depend on the thread mapping attribute.

I agree that if it were canonicalized away, we'd have many issues in our passes. That is why this PR doesn't try to eliminate the scf.forall.

@yzhang93 (Contributor) commented:

> I agree that if it were canonicalized away, we'd have many issues in our passes. That is why this PR doesn't try to eliminate the scf.forall.

The reason I'm discussing this here is that when we initially created these individual passes, the general principle was to reuse as many upstream functions as possible, such as scf::tileAndFuseProducerOfSlice in this pass (it is also used in FusePackIntoLoop). I think the upstream function also doesn't handle the case when the loop count is 1, since such a loop should have already been canonicalized away. Is there anything we can do to keep using this upstream function while adding support for the additional case? CC: @Abhishek-Varma

@newling (Contributor, Author) commented Dec 10, 2024

> I agree that if it were canonicalized away, we'd have many issues in our passes. That is why this PR doesn't try to eliminate the scf.forall.

> The reason I'm discussing this here is that when we initially created these individual passes, the general principle was to reuse as many upstream functions as possible, such as scf::tileAndFuseProducerOfSlice in this pass (it is also used in FusePackIntoLoop). I think the upstream function also doesn't handle the case when the loop count is 1, since such a loop should have already been canonicalized away. Is there anything we can do to keep using this upstream function while adding support for the additional case? CC: @Abhishek-Varma

I've updated it to use the upstream function if there is a slice present. I don't see a way to use that function if there is no slice.
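
For orientation, here is a minimal sketch of the two paths described above, pieced together from the review snippets quoted further down. It is not the exact code in this PR: the function name fuseFillIntoForall is made up for illustration, accessor names such as getTiedBlockArgument are recalled from the upstream scf::ForallOp API, and the upstream fusion call is only indicated by a comment since its exact invocation is not quoted in this conversation.

 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"

 using namespace mlir;

 // Sketch: fuse a linalg.fill into the scf.forall that consumes its result as
 // a shared_outs operand, whether or not the loop body slices that tensor.
 static LogicalResult fuseFillIntoForall(RewriterBase &rewriter,
                                         linalg::FillOp fillOp) {
   // Bail if the fill feeds anything other than a single scf.forall.
   if (!fillOp->hasOneUse()) return failure();
   OpOperand &use = *fillOp->getUses().begin();
   auto forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
   if (!forallOp) return failure();

   // The forall block argument tied to the filled tensor (shared_outs).
   BlockArgument bbArg = forallOp.getTiedBlockArgument(&use);

   // Is the filled tensor sliced inside the loop body?
   tensor::ExtractSliceOp sliceOp;
   for (Operation *user : bbArg.getUsers())
     if (auto candidate = dyn_cast<tensor::ExtractSliceOp>(user))
       sliceOp = candidate;

   if (sliceOp) {
     // 2x2 / 4x4 tiling: a slice exists, so the upstream utility
     // scf::tileAndFuseProducerOfSlice can be applied to it, as before this PR.
     // (invoke the upstream function on sliceOp here)
     return success();
   }

   // 1x1 tiling: no slice. Create the fill at the top of the forall body,
   // writing into the block argument directly. Uses of the block argument and
   // of the old fill's result still need rewiring; see the review thread on
   // lines 95 to 96 further down.
   rewriter.setInsertionPointToStart(forallOp.getBody());
   auto fusedFill =
       rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);
   (void)fusedFill;
   return success();
 }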

if (std::distance(fillUses.begin(), fillUses.end()) != 1) return;
OpOperand &fillUse = *fillUses.begin();
auto forallOp = dyn_cast<scf::ForallOp>(fillUse.getOwner());
if (!forallOp) return;
Contributor:

For these "return" situations, I would add some debug messages instead of silently returning.

Comment on lines 45 to 46
ResultRange::use_range fillUses = fillOp->getUses();
if (std::distance(fillUses.begin(), fillUses.end()) != 1) return;
Contributor:

I think you can just check `if (fillOp.hasOneUse())`.

Contributor Author:

Doesn't look like that is a method on use_range.

Contributor:

It should work on ops. fillOp->hasOneUse() doesn't work?

Contributor Author:

Done.

tensor::ExtractSliceOp extractSliceOp;
for (Operation *user : bbArg.getUsers()) {
if (auto nxt = dyn_cast<tensor::ExtractSliceOp>(user)) {
if (extractSliceOp) return;
Contributor:

I don't understand why you return here. Should it be break instead?

Contributor:

I prefer the original function for this purpose:

auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp;
  });

@newling (Contributor, Author) commented Dec 10, 2024:

Multiple extract_slice ops -- bailing as this is unexpected.

return forallOp;
// In the case where there are no extract_slice ops, we manually create the
// fill at the beginning of the forall body.
assert(!extractSliceOp);
Contributor:

This is unnecessary.

Contributor Author:

I find code with lots of asserts easier to read, but I've changed it to an if-else.

Comment on lines 76 to 77
// In the case where there are no extract_slice ops, we manually create the
// fill at the beginning of the forall body.
Contributor:

It would be better if you added some comments explaining when this situation happens (i.e., when the scf.forall loop count is 1).

Comment on lines 80 to 82
Value scalar = fillOp.value();
Location loc = fillOp.getLoc();
auto fusedFill = rewriter.create<linalg::FillOp>(loc, scalar, bbArg);
Contributor:

Suggested change: replace

Value scalar = fillOp.value();
Location loc = fillOp.getLoc();
auto fusedFill = rewriter.create<linalg::FillOp>(loc, scalar, bbArg);

with

auto fusedFill = rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);

Contributor:

This change does not show up in the latest revision.

Comment on lines 95 to 96
// Do not use the result of the old fill.
rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]);
Contributor:

Can't this be included in the above replaceUsesWithIf?

Contributor Author:

No, different ops.
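
To unpack that brief reply: in the no-slice path sketched earlier, the two rewrites act on two different values, so they cannot be folded into a single call. A hedged reconstruction follows (the replaceUsesWithIf call referenced above is not quoted in this excerpt, so its exact predicate is an assumption):

 // Inside the forall body: reads of the block argument should now read the
 // freshly created fill (every use except the new fill op itself).
 rewriter.replaceUsesWithIf(bbArg, fusedFill.getResults()[0], [&](OpOperand &u) {
   return u.getOwner() != fusedFill.getOperation();
 });

 // Outside the loop: the original fill's result is no longer needed, so its
 // uses are forwarded to its init tensor, leaving the old linalg.fill dead.
 rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]);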

Comment on lines 44 to 51
// check that the operand of scf.forall is not the filled tensor, because the
// fill will take place inside the scf.forall.
// CHECK: %[[FORALL:.*]] = scf.forall (%[[ARG1:.*]]) in (1)
// CHECK-SAME: shared_outs(%[[ARG2:.*]] = %[[FUNCARG]])

// check for the new fill
// CHECK: %[[NEWFILL:.*]] = linalg.fill
// CHECK-SAME: outs(%[[ARG2]] : tensor<8xi8>) -> tensor<8xi8>
Contributor:

I feel it's not readable when you mix CHECK lines with comments. I'd prefer you put all the comments above and keep the // CHECK section compact.

Contributor Author:

Updated.

newling enabled auto-merge (squash) on December 11, 2024 at 20:00
jtuyls disabled auto-merge on December 11, 2024 at 21:19
jtuyls merged commit db10c75 into nod-ai:main on Dec 11, 2024
7 checks passed
newling added a commit that referenced this pull request Dec 12, 2024
…983)

This is very similar to #976 

When the AIE grid we're using is `m x n`, where one or both of `m` and `n` is `1`, for a matmul we get pack ops that produce operands for the matmul directly; i.e., instead of `pack->extract_slice->matmul` we have `pack->matmul`.
newling deleted the fix_fill_pass branch on December 12, 2024 at 23:06