Add bitonic topk #3862
base: develop
Conversation
template <class T, class U>
constexpr bool float_equal(T x, U y)
{
    if constexpr(is_integral<T>{} or is_integral<U>{})
Shouldn't both T and U be of integral type?
Copilot reviewed 15 out of 35 changed files in this pull request and generated no comments.
Files not reviewed (20)
- src/CMakeLists.txt: Language not supported
- src/include/migraphx/op/topk.hpp: Language not supported
- src/include/migraphx/raw_data.hpp: Language not supported
- src/include/migraphx/rewrite_topk.hpp: Language not supported
- src/include/migraphx/shape.hpp: Language not supported
- src/include/migraphx/tensor_view.hpp: Language not supported
- src/rewrite_reduce.cpp: Language not supported
- src/rewrite_topk.cpp: Language not supported
- src/shape.cpp: Language not supported
- src/targets/gpu/jit/topk.cpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/bit.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/float_equal.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/index.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/math.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/operators.hpp: Language not supported
- src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp: Language not supported
constexpr uint64_t bit_ceil(uint64_t x) noexcept
{
    if(x <= 1)
This if clause can be removed for faster overall GPU performance.
I don't see a difference.
The if clause is redundant here, so it adds extra instructions on the GPU. Not sure what kind of difference you are looking for :-)
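For context, one common constexpr implementation of bit_ceil looks like the sketch below (an assumption for illustration; the kernel's actual body is not shown in this excerpt). In this bit-smearing form the guard only changes the result for x == 0, which would otherwise come out as 0 instead of 1:

#include <cstdint>

constexpr std::uint64_t bit_ceil_sketch(std::uint64_t x) noexcept
{
    if(x <= 1)
        return 1;
    // Round up to the next power of two by smearing the highest set bit downward.
    x -= 1;
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x |= x >> 32;
    return x + 1;
}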
{
    friend constexpr bool operator<(const topk_pair& x, const topk_pair& y)
    {
        if(not float_equal(x.key, y.key))
float_equal does unnecessary calculations for the logic required below; it would be more efficient to just check x.key < y.key first. This logic could be rewritten to be faster along the lines of:
if(x.key < y.key) return true;
if(x.key > y.key) return false;
return x.val < y.val;
That doesn't really make a difference; most likely CSE removes the duplicate comparisons.
Incorrect or duplicate code is best handled at the coding stage :-)
MIGRAPHX_ASSERT(trimmed_n <= n);

array<pair, nper_lane> local_buf;
for(index_int i : range(nper_lane))
Why initialize the whole array when it is being (partially or completely) filled in later?
Because we can't use uninitialized values.
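A small standalone illustration of that point (plain C++; the sentinel value init and the buffer sizes are hypothetical, not taken from the kernel): the fixed-size buffer is read in full later, so any slot that is not overwritten must hold a well-defined sentinel rather than garbage.

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <limits>

int main()
{
    constexpr std::size_t nper_lane = 8;
    const float init = std::numeric_limits<float>::lowest(); // sentinel for unfilled slots

    std::array<float, nper_lane> local_buf;
    local_buf.fill(init); // initialize everything up front

    // Only part of the buffer receives real data...
    for(std::size_t i = 0; i < 5; i++)
        local_buf[i] = static_cast<float>(i);

    // ...but the whole buffer is read here, so the remaining slots must be defined.
    std::sort(local_buf.begin(), local_buf.end(), std::greater<>{});
}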
__shared__ pair buf[aligned_n];
// Copy to LDS
idx.local_stride(aligned_n, [&](auto i) {
    auto key = i < x.get_shape().elements() ? x[i] : init;
Is x.get_shape().elements() a constexpr?
Yes, but more importantly this returns an integral_constant so it gets folded in the AST.
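For illustration only (using std::integral_constant and a toy shape class rather than the MIGraphX kernel types), returning the element count as an integral_constant makes it a compile-time value the compiler can fold:

#include <cstddef>
#include <type_traits>

// Toy shape: the element count is encoded in the type, so elements()
// yields std::integral_constant instead of a runtime std::size_t.
template <std::size_t N>
struct toy_shape
{
    static constexpr auto elements() { return std::integral_constant<std::size_t, N>{}; }
};

int main()
{
    using shape = toy_shape<1024>;
    // The count is known at compile time, so bound checks against it can fold away.
    static_assert(shape::elements() == 1024);
    constexpr bool in_bounds = std::size_t{100} < shape::elements();
    static_assert(in_bounds);
}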
template <class T, class U>
struct topk_pair
    : conditional_t<(sizeof(T) >= sizeof(U)), topk_pair_t_u<T, U>, topk_pair_u_t<T, U>>,
      partially_ordered<topk_pair<T, U>>
Incidentally, partially_ordered is an incorrect choice for topk_pair, which is actually a strong_ordered or at least a weak_ordered type.
template <class T>
struct partially_ordered
{
Shouldn't this struct also contain the primitives/operators that were defined for less_than_comparable? Is there a reason that > isn't defined in this case? I am not a fan of these kinds of preprocessor macros. Regardless, this ordering type of partially_ordered isn't correct; we should be able to exactly compare two topk tuples.
That's because a < b != b > a when the keys are equal, since the order of the indices is always compared with < even when the key comparison is >. It's probably better to remove the comparison operators and use a custom comparator instead, as this gets confusing.
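A minimal sketch of what such a custom comparator could look like (hypothetical names; pair_t stands in for topk_pair and largest for the configured sort direction): keys follow the configured comparison, while index ties always break with <.

#include <cstdint>

// Stand-in for topk_pair: a key plus its original index.
struct pair_t
{
    float key;
    std::uint32_t val;
};

// Custom comparator: the key comparison direction depends on 'largest',
// but ties on the key are always broken by ascending index, so the index
// order never flips with the sort direction.
struct topk_compare
{
    bool largest = true;

    constexpr bool operator()(const pair_t& x, const pair_t& y) const
    {
        if(x.key < y.key)
            return not largest;
        if(y.key < x.key)
            return largest;
        return x.val < y.val;
    }
};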
So the stable sorting is 2.5x slower on some configs. We can recover some perf by lowering the split threshold. Most of the perf cost comes from the wavefront sorting: as the number of elements to sort gets larger, it seems to increase the register pressure. For now, I think this could be merged, and I can investigate the perf issues in the future.
@@ -36,6 +36,7 @@ struct module;
/// Rewrite topk operators ideally to better performing operators
struct MIGRAPHX_EXPORT rewrite_topk
{
    std::size_t split_threshold = 8192;
I think these naked constants need better handling. For example, there is n_threshold, which is double this split_threshold; better to derive it as n_threshold/2.
split_threshold sets the n_threshold. This is so we can set the threshold to something different from the default.
Let me set n_threshold to 0 in the constructor instead to avoid confusion here.
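A hypothetical sketch of the relationship being discussed (illustrative only; the field names follow the conversation, but the exact wiring inside rewrite_topk is an assumption): n_threshold defaults to a 0 sentinel and is derived from split_threshold when it is not set explicitly.

#include <cstddef>

struct rewrite_topk_sketch
{
    std::size_t split_threshold = 8192;
    std::size_t n_threshold     = 0; // 0 means "derive from split_threshold"

    std::size_t effective_n_threshold() const
    {
        // Per the review comment above, the default n_threshold is double the split threshold.
        return n_threshold == 0 ? split_threshold * 2 : n_threshold;
    }
};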
return [=](auto p1, auto p2) {
    auto [x, i] = p1;
    auto [y, j] = p2;
    if(not float_equal(x, y))
compare_pair should function along the lines of std::less(), and there is no reason to first compare using float_equal, fail that test, and then run a compare. The comparison should be for less_than or something like that in the first calculating step. This will make it logical and also more efficient. Thanks.
This is doing a lexicographical-like comparison: if the first elements are not equal, do the comparison of the first elements only, and if they are equal, do a comparison of the second elements. One difference is that the first elements are compared with the custom comparator that is passed in, and the second elements are always compared with <, because the indices always have the same order regardless of what largest is set to.
Looking at float_equal:
constexpr bool float_equal(T x, U y)
{
    if constexpr(is_integral<T>{} and is_integral<U>{})
        return x == y;
    return not(x < y or x > y);
}
Consider the first case, when x1 < y1: the float_equal call above would first compare with <, and then do a follow-up compare() which is another std::less(). This should be just one comparison, not two steps. The code above translates to: if(x1 < y1) return std::less(x1, y1). All that was required was a single comparison.
In the second case, when x1 > y1, the code translates to if(x1 < y1 or x1 > y1) return std::less(x1, y1). This has two extra comparisons.
So the lambda could be written along these lines (no need for float_equal()):
if(x < y) return true;
if(x > y) return false;
return (i < j);
float_equal doesn't do x < y; it does std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and std::nextafter(x, std::numeric_limits<T>::max()) >= y.
"So the lambda could be written along these lines:"
No it can't, because we aren't doing x < y, we are doing compare(x, y).
This is for the ref version anyway, so I would prefer to keep it simple and straightforward.
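For reference, a host-side sketch of that nextafter-based check (the function name is hypothetical; the kernel version differs in how it handles mixed types):

#include <cmath>
#include <limits>

// x and y compare equal when y lies within one representable step of x
// in both directions, which is what the nextafter bounds express.
template <class T>
bool float_equal_ulp(T x, T y)
{
    return std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and
           std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}

// float_equal_ulp(1.0f, std::nextafter(1.0f, 2.0f)) is true;
// float_equal_ulp(1.0f, 1.1f) is false.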
This build is not recommended to merge 🔴
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

❌ llama2_7b: ERROR - check error output
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
    model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/llama2_7b/decoder_model.onnx

❌ #qwen1.5-7b: ERROR - check error output
usage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER] [--batch BATCH] [--fill1] [--fill0] [--fp16] [--argmax] [--verbose] [--tolerance TOLERANCE] [--input-dim INPUT_DIM] [--target TARGET] [--ort-run] [--ort-logging] [--disable-offload-copy] [--disable-fast-math] [--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask position_ids 1 256 @attention_mask 1 256 @position_ids 1 256

❌ phi3-3.8b: ERROR - check error output
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
    model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/phi3-3.8b/model.onnx

🔴 mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output

❌ #whisper-large-encoder: ERROR - check error output
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
    model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/convolution.hpp:100: normalize_compute_shape: CONVOLUTION: mismatched channel numbers
This implements a faster GPU topk. It also adds a rewrite_topk pass that will split a large topk into 2 operators; this needs the indices to be passed along, as they won't be the same for one batch.
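For readers unfamiliar with the idea, below is a small host-side illustration of splitting one large topk into two stages while carrying the original indices through the first stage (plain C++ with std::partial_sort; the chunking, names, and tie-breaking are illustrative, not the MIGraphX implementation):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

using entry = std::pair<float, std::size_t>; // (value, original index)

// Reference topk over (value, index) entries: largest values first, ties by index.
std::vector<entry> topk(std::vector<entry> in, std::size_t k)
{
    k = std::min(k, in.size());
    std::partial_sort(in.begin(), in.begin() + k, in.end(), [](const entry& a, const entry& b) {
        if(a.first != b.first)
            return a.first > b.first; // larger value wins
        return a.second < b.second;   // tie: smaller original index wins
    });
    in.resize(k);
    return in;
}

// Two-stage topk: a per-chunk topk that keeps original indices, then a final
// topk over the surviving candidates. The indices must travel with the values,
// since after the first stage they are no longer implied by position.
std::vector<entry> split_topk(const std::vector<float>& x, std::size_t k, std::size_t chunk)
{
    std::vector<entry> partial;
    for(std::size_t start = 0; start < x.size(); start += chunk)
    {
        std::vector<entry> local;
        for(std::size_t i = start; i < std::min(start + chunk, x.size()); i++)
            local.push_back({x[i], i});
        auto best = topk(local, k);
        partial.insert(partial.end(), best.begin(), best.end());
    }
    return topk(partial, k);
}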