Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RAISE-BP] Add support for arith.remsi|remui as tt.addptr input #1570

Merged

Conversation

mfrancepillois
Copy link
Contributor

@mfrancepillois mfrancepillois commented Jul 5, 2024

  • Add minimal support for handling arith.remsi|remui as tt.addptr input.
  • Improve handling of unfolded arithmetic operations when evaluating the modulo property and constant values.

Closes Issue: #1436 and #1482

- Add minimal support for handling `arith.remsi|remui` as `tt.addptr` input.
- Improve handling of unfolded arithmetic operations when evaluating the modulo property and constant values.

Closes Issue: #1430 and #1482

Signed-off-by: Maxime France-Pillois <[email protected]>
Copy link
Contributor

@victor-eds victor-eds left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have code actually exercising every case you're implementing here?


// This function folds the `op` operation and returns the constant value if it
// has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
std::optional<int64_t> getFoldedConstantValue(Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we asssert on op having exactly 1 result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer not to add an assert on the size, because it is not a error. It is not fundamentally wrong to call this function with a op with multiple results. But we cannot fold the operation to get a constant value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then can we early exit in that case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that assert is not for error handling (especially since it's not included in release builds), but for encoding preconditions/assumption. So if you should only ever call this function on ops with 1 results, then putting that in an assert is a good idea.

}
}

if (results.size() != 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (results.size() != 1) {
if (results.empty()) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not the case (see the answer to the previous point), the two codes are not semantically equivalent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would make sense if we early exit

Comment on lines +554 to +555
// - (tl.arange(0, end)[:, None] % mod), or
// - (tl.arange(0, end)[None, :] % mod)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this can cause issues, can we maybe change the order in an earlier pass?

Comment on lines 223 to 230
op->emitWarning(
"TritonRaiseBlockPointer: allowing adding pointer state with "
"modulo in dim 0 to "
"another pointer state with offset in dim 0.\nPlease verify the "
"operand that contains a scalar is meant to increment pointers in "
"dim1. If that is not the case it WILL LEAD TO WRONG COMPILATION "
"RESULTS.\n\nTo avoid this warning, use expand_dims (instead of "
"splat) to explicitly specify which dimension contains the "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Format

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if this may lead to wrong code, I'd just fail. We can think of a way to check this, but I'd sacrifice performance for correctness here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I have removed this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finally, I've some doubts about removing this special case. Indeed, this case seems to occur quite often when dealing with generated code. It is the case for 03-matrix-multiplication.py, which results in generating this pattern. So, I'm afraid that removing this special case and triggering a failure instead will reduce significantly the interest in this pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As safe handling of modulo requires more care. Issue #1784 has been created.
IMO, we can then keep this unsafe path out the this PR, and design a better support within the newly created issue.

return failure();
}

if (state.getRank() == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to switch stmt

@etiotto etiotto requested review from victor-eds and FMarno July 30, 2024 13:02
@whitneywhtsang whitneywhtsang merged commit 57a10af into llvm-target Aug 7, 2024
4 checks passed
@whitneywhtsang whitneywhtsang deleted the maxime/raise-add-ptr-remsi-remui-support branch August 7, 2024 14:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
5 participants