[RAISE-BP] Add support for `arith.remsi|remui` as `tt.addptr` input #1570
- Add minimal support for handling `arith.remsi|remui` as `tt.addptr` input.
- Improve handling of unfolded arithmetic operations when evaluating the modulo property and constant values.

Closes Issue: #1430 and #1482

Signed-off-by: Maxime France-Pillois <[email protected]>
third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
Can we have code actually exercising every case you're implementing here?
// This function folds the `op` operation and returns the constant value if it
// has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
std::optional<int64_t> getFoldedConstantValue(Operation *op) {
Can we assert on `op` having exactly 1 result?
I prefer not to add an assert on the size, because it is not an error. It is not fundamentally wrong to call this function with an op that has multiple results; we just cannot fold such an operation to get a constant value.
Then can we early exit in that case?
I would say that `assert` is not for error handling (especially since it's not included in release builds), but for encoding preconditions/assumptions. So if you should only ever call this function on ops with 1 result, then putting that in an assert is a good idea.
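The two options discussed in this thread can be sketched outside of MLIR. The snippet below is a minimal illustration, not the code in this PR: `FoldedValues`, `foldedConstantEarlyExit`, and `foldedConstantAsserting` are hypothetical names, and a vector of optional integers stands in for the fold results.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the results of folding an operation: one entry per
// result, holding a value only if that result folded to a constant.
using FoldedValues = std::vector<std::optional<int64_t>>;

// Option A: a multi-result input is a valid call that simply cannot be folded
// to a single constant, so return std::nullopt early.
std::optional<int64_t> foldedConstantEarlyExit(const FoldedValues &results) {
  if (results.size() != 1)
    return std::nullopt;
  return results.front();
}

// Option B: "exactly one result" is a precondition. The assert documents it
// and is compiled out when NDEBUG is defined, so it is not error handling.
std::optional<int64_t> foldedConstantAsserting(const FoldedValues &results) {
  assert(results.size() == 1 && "expected exactly one result");
  return results.front();
}
```

With option A, callers may pass any op and must handle `std::nullopt`. With option B, callers are responsible for never passing a multi-result op, and a debug build will catch violations.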
}
}

if (results.size() != 1) {
Suggested change:
- if (results.size() != 1) {
+ if (results.empty()) {
If that precondition does not hold (see the answer to the previous point), the two conditions are not semantically equivalent.
Would make sense if we early exit
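To make the equivalence question concrete: the two conditions differ exactly when there is more than one result. The sketch below uses a plain `std::vector<int>` in place of the fold results, and both helper names are hypothetical.

```cpp
#include <vector>

// Rejects every result list that does not hold exactly one element.
bool rejectsUnlessSingle(const std::vector<int> &results) {
  return results.size() != 1;
}

// Rejects only the empty result list; two or more elements are accepted.
bool rejectsIfEmpty(const std::vector<int> &results) {
  return results.empty();
}
```

For zero or one element both checks agree; for two elements only the first one rejects. So `results.empty()` is a safe replacement only once multi-result ops have been ruled out earlier, whether by an assert or by an early exit.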
// - (tl.arange(0, end)[:, None] % mod), or
// - (tl.arange(0, end)[None, :] % mod)
If this can cause issues, can we maybe change the order in an earlier pass?
op->emitWarning(
    "TritonRaiseBlockPointer: allowing adding pointer state with "
    "modulo in dim 0 to "
    "another pointer state with offset in dim 0.\nPlease verify the "
    "operand that contains a scalar is meant to increment pointers in "
    "dim1. If that is not the case it WILL LEAD TO WRONG COMPILATION "
    "RESULTS.\n\nTo avoid this warning, use expand_dims (instead of "
    "splat) to explicitly specify which dimension contains the "
Format
Also, if this may lead to wrong code, I'd just fail. We can think of a way to check this, but I'd sacrifice performance for correctness here.
Agree. I have removed this case.
On reflection, I have some doubts about removing this special case. It seems to occur quite often in generated code; for example, 03-matrix-multiplication.py produces this pattern. So I'm afraid that removing the special case and triggering a failure instead will significantly reduce the usefulness of this pass.
Safe handling of modulo requires more care, so issue #1784 has been created.
IMO, we can keep this unsafe path out of this PR and design better support within the newly created issue.
return failure();
}

if (state.getRank() == 1) {
Change to switch stmt
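As a rough illustration of the suggested shape, the sketch below dispatches on a rank value with a `switch` rather than chained `if` tests. It is not the code under review: `describeRank`, the handled ranks, and the returned strings are assumptions chosen only to show the structure.

```cpp
#include <string>

// Hypothetical helper: dispatch on the pointer-state rank with a switch
// instead of chained `if (rank == 1) ... else if (rank == 2) ...` tests.
std::string describeRank(int rank) {
  switch (rank) {
  case 1:
    return "rank-1";
  case 2:
    return "rank-2";
  default:
    // Every unsupported rank is handled by this single explicit path.
    return "unsupported";
  }
}
```

The `default` label gives unsupported ranks one visible exit, which in the pass could map to `return failure();`.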
`arith.remsi|remui` as `tt.addptr` input. Closes Issue: #1436 and #1482