Add bitonic topk #3862


Open · wants to merge 98 commits into base: develop

Conversation

@pfultz2 (Collaborator) commented Mar 3, 2025

This implements a faster GPU topk.

  • Updated the ref version of topk to take a parameter for the indices, and also updated it to handle any layout.
  • Added a GPU bitonic topk version. This does a bitonic sort per wavefront and then a partial sort in shared memory to get the final topk values.
  • Added a rewrite_topk pass that splits large topks into 2 operators. This needs the indices to be passed along, as they won't be the same for one batch.
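The split in the last bullet can be sketched host-side like this. This is only an illustration of the idea, not the MIGraphX pass itself; `topk_with_indices` is a hypothetical helper standing in for a topk operator with index output:

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not the MIGraphX API): top-k largest values of a
// range together with their source indices.
std::vector<std::pair<float, std::size_t>>
topk_with_indices(const std::vector<float>& x, std::size_t k)
{
    std::vector<std::pair<float, std::size_t>> pairs(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
        pairs[i] = {x[i], i};
    std::partial_sort(pairs.begin(),
                      pairs.begin() + k,
                      pairs.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
    pairs.resize(k);
    return pairs;
}

// Sketch of the split: topk over each chunk first, then a final topk over
// the concatenated partial results. The indices must be carried through
// because each chunk reorders its elements differently.
std::vector<std::pair<float, std::size_t>>
split_topk(const std::vector<float>& x, std::size_t k, std::size_t nchunks)
{
    std::size_t chunk = (x.size() + nchunks - 1) / nchunks;
    std::vector<std::pair<float, std::size_t>> partial;
    for(std::size_t c = 0; c < nchunks; c++)
    {
        std::size_t start = c * chunk;
        std::size_t stop  = std::min(start + chunk, x.size());
        if(start >= stop)
            break;
        std::vector<float> piece(x.begin() + start, x.begin() + stop);
        for(auto [v, i] : topk_with_indices(piece, std::min(k, piece.size())))
            partial.push_back({v, i + start}); // remap local to global index
    }
    // Second-stage topk over the partial winners, mapping back to the
    // original indices saved in the first stage.
    std::vector<float> vals;
    vals.reserve(partial.size());
    for(const auto& p : partial)
        vals.push_back(p.first);
    std::vector<std::pair<float, std::size_t>> out;
    for(auto [v, i] : topk_with_indices(vals, std::min(k, vals.size())))
        out.push_back({v, partial[i].second});
    return out;
}
```

The key point is the index remap in the first stage: the final topk selects among partial winners, so without the carried indices the result could not be mapped back to original positions.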

template <class T, class U>
constexpr bool float_equal(T x, U y)
{
if constexpr(is_integral<T>{} or is_integral<U>{})
Contributor

Shouldn't both T & U be of integral type?

@pfultz2 pfultz2 requested a review from Copilot April 1, 2025 15:19
@Copilot Copilot AI left a comment

Copilot reviewed 15 out of 35 changed files in this pull request and generated no comments.

Files not reviewed (20)
  • src/CMakeLists.txt: Language not supported
  • src/include/migraphx/op/topk.hpp: Language not supported
  • src/include/migraphx/raw_data.hpp: Language not supported
  • src/include/migraphx/rewrite_topk.hpp: Language not supported
  • src/include/migraphx/shape.hpp: Language not supported
  • src/include/migraphx/tensor_view.hpp: Language not supported
  • src/rewrite_reduce.cpp: Language not supported
  • src/rewrite_topk.cpp: Language not supported
  • src/shape.cpp: Language not supported
  • src/targets/gpu/jit/topk.cpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/bit.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/float_equal.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/index.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/math.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/operators.hpp: Language not supported
  • src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp: Language not supported


constexpr uint64_t bit_ceil(uint64_t x) noexcept
{
if(x <= 1)
Contributor

This if clause can be removed for faster overall GPU performance.

Collaborator Author

I don't see a difference.

Contributor

The if clause is redundant here, and thus extra instructions for a GPU. Not sure what kind of difference you are looking for :-)
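As an aside, a self-contained sketch of a bit_ceil in this style (not necessarily the PR's exact code): with the `__builtin_clzll` formulation below, the `x <= 1` guard is actually load-bearing, since `x == 0` or `x == 1` would otherwise shift by 64, which is undefined behavior:

```cpp
#include <cstdint>

// Round up to the next power of two. The x <= 1 guard handles x == 0 and
// x == 1, where __builtin_clzll(x - 1) below would be called on 0
// (undefined) or produce a shift by 64 (also undefined).
constexpr std::uint64_t bit_ceil(std::uint64_t x) noexcept
{
    if(x <= 1)
        return 1;
    return std::uint64_t{1} << (64 - __builtin_clzll(x - 1));
}
```

Whether the guard is removable thus depends on the formulation; for a loop-based version it may genuinely be redundant.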

{
friend constexpr bool operator<(const topk_pair& x, const topk_pair& y)
{
if(not float_equal(x.key, y.key))
Contributor

float_equal does unnecessary calculations for the logic required below; it would be more efficient to just check x.key < y.key first.

This logic could be rewritten to be faster along these lines:

if(x.key < y.key) return true;
if(x.key > y.key) return false;
return x.val < y.val;

Collaborator Author

That doesn't really make a difference; most likely CSE removes the duplicate comparisons.

Contributor

Incorrect or duplicate code is best handled in the coding stage :-)

MIGRAPHX_ASSERT(trimmed_n <= n);

array<pair, nper_lane> local_buf;
for(index_int i : range(nper_lane))
Contributor

Why initialize the whole array when it is being (partially or completely) filled in later?

Collaborator Author

Because we can't use uninitialized values.

__shared__ pair buf[aligned_n];
// Copy to LDS
idx.local_stride(aligned_n, [&](auto i) {
auto key = i < x.get_shape().elements() ? x[i] : init;
Contributor

Is x.get_shape().elements() a constexpr?

Collaborator Author

Yes, but more importantly this returns an integral_constant, so it gets folded in the AST.
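The integral_constant behavior can be illustrated standalone with std::integral_constant (the MIGraphX kernels use their own implementation; `static_shape` and `padded_size` below are made-up names for illustration):

```cpp
#include <cstddef>
#include <type_traits>

// Mimics a shape type that carries its element count in the type, so that
// elements() returns an integral_constant rather than a runtime value.
template <std::size_t N>
struct static_shape
{
    static constexpr auto elements()
    {
        return std::integral_constant<std::size_t, N>{};
    }
};

// Because elements() returns an integral_constant, its value is available
// as a constant expression even through a deduced function parameter.
template <class Shape>
constexpr std::size_t padded_size(Shape)
{
    constexpr std::size_t n = decltype(Shape::elements())::value;
    return n + 1;
}
```

This is why the `i < x.get_shape().elements()` bound in the LDS copy can fold away at compile time.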

template <class T, class U>
struct topk_pair
: conditional_t<(sizeof(T) >= sizeof(U)), topk_pair_t_u<T, U>, topk_pair_u_t<T, U>>,
partially_ordered<topk_pair<T, U>>
Contributor

Incidentally, partially_ordered is an incorrect choice for topk_pair, which is actually a strongly ordered, or at least weakly ordered, type.


template <class T>
struct partially_ordered
{
Contributor

Shouldn't this struct also contain the primitives/operators that were defined for less_than_comparable? Is there a reason that > isn't defined in this case? I am not a fan of these kinds of preprocessor macros.
Regardless, this ordering type of partially_ordered isn't correct. We should be able to exactly compare two topk tuples.

Collaborator Author

That's because a < b != b > a when the keys are equal, because the order of the indices is always < even when the key comparison is >. It's probably better to remove the comparison operators and use a custom comparator instead, as this gets confusing.

@pfultz2 (Collaborator Author) commented Apr 8, 2025

So the stable sorting is 2.5x slower on some configs. We can recover some perf by lowering the split threshold.

Most of the perf cost comes from the wavefront sorting; as the number of elements to sort gets larger, it seems to increase the register pressure. For now, I think this could be merged, and I can investigate the perf issues in the future.

@@ -36,6 +36,7 @@ struct module;
/// Rewrite topk operators ideally to better performing operators
struct MIGRAPHX_EXPORT rewrite_topk
{
std::size_t split_threshold = 8192;
Contributor

I think these naked constants need better handling. For example, there is n_threshold, which is double this split_threshold. Better to derive it as n_threshold/2.

Collaborator Author

split_threshold sets the n_threshold. This is so we can set the threshold different from the default.

Collaborator Author

Let me set n_threshold to 0 in the constructor instead to avoid confusion here.

return [=](auto p1, auto p2) {
auto [x, i] = p1;
auto [y, j] = p2;
if(not float_equal(x, y))
Contributor

compare_pair should function along the lines of std::less(), and there is no reason to first compare using float_equal, fail that test, and then run a compare. The comparison should be for less_than or something like that in the first calculating step. This will make it logical and also more efficient. Thanks.

Collaborator Author
@pfultz2 commented Apr 14, 2025

This is doing a lexicographical-like comparison: if the first elements are not equal, then do the comparison of the first elements only, and if they are equal, do a comparison of the second elements. One difference is that the first elements are compared with the custom comparator that is passed in, and the second elements are always compared with < because the indices always have the same order regardless of what largest is set to.
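For illustration, a comparator of this shape could be written as follows. This is a sketch only, with float_equal stubbed as an exact comparison (the real one is wider), and the compare_pair body simplified:

```cpp
#include <utility>

// Simplified stand-in for float_equal: exact comparison only.
constexpr bool float_equal_simple(float x, float y)
{
    return not(x < y or x > y);
}

// Sketch of compare_pair: keys use the supplied comparator (e.g. > for
// largest), while tied keys always fall back to ascending indices.
template <class Compare>
constexpr auto compare_pair(Compare compare)
{
    return [=](auto p1, auto p2) {
        auto [x, i] = p1;
        auto [y, j] = p2;
        if(not float_equal_simple(x, y))
            return compare(x, y);
        return i < j;
    };
}
```

With `compare` set to `>`, largest-first order still breaks ties by ascending source index, which is the asymmetry being described.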

Contributor

constexpr bool float_equal(T x, U y)
{
    if constexpr(is_integral<T>{} and is_integral<U>{})
        return x == y;
    return not(x < y or x > y);
}

Consider the first case, when x1 < y1: the above float_equal call would first compare <, and then do a follow-up compare(), which is another std::less(). This should be just one comparison, not two steps. The above code translates to: if(x1 < y1) return std::less(x1, y1). All that was required was a single comparison.

In the second case, when x1 > y1, this code translates to if(x1 < y1 or x1 > y1) return std::less(x1, y1). This has two extra comparisons.

So the lambda could be written along these lines (no need for float_equal()):

if(x < y) return true;
if(x > y) return false;
return (i < j);

Collaborator Author

float_equal doesn't do x < y, it does std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and std::nextafter(x, std::numeric_limits<T>::max()) >= y.

So the lambda could be written along these lines:

No it can't, because we aren't doing x < y, we are doing compare(x, y).

This is for the ref version anyway, so I would prefer to keep it simple and straightforward.
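For reference, the ulp-widened equality described above can be sketched as a host-side illustration (not the kernel's float_equal itself):

```cpp
#include <cmath>
#include <limits>

// Equality widened by one ulp in each direction: y counts as equal to x
// if it lies within [nextafter(x, lowest), nextafter(x, max)].
template <class T>
bool float_equal_ulp(T x, T y)
{
    return std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and
           std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}
```

This is why replacing it with plain `<`/`>` comparisons is not an exact rewrite: values one ulp apart compare equal here but not under `not(x < y or x > y)`.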

@migraphx-bot (Collaborator)

| Test | Batch | Rate new (6915ef) | Rate old (ecb974) | Diff | Compare |
|------|------:|------------------:|------------------:|-----:|---------|
| torchvision-resnet50 | 64 | 3,253.49 | 3,231.64 | 0.68% | |
| torchvision-resnet50_fp16 | 64 | 6,897.82 | 6,867.39 | 0.44% | |
| torchvision-densenet121 | 32 | 2,442.88 | 2,432.50 | 0.43% | |
| torchvision-densenet121_fp16 | 32 | 4,239.06 | 4,212.33 | 0.63% | |
| torchvision-inceptionv3 | 32 | 1,623.24 | 1,613.30 | 0.62% | |
| torchvision-inceptionv3_fp16 | 32 | 2,707.12 | 2,696.30 | 0.40% | |
| cadene-inceptionv4 | 16 | 754.25 | 750.25 | 0.53% | |
| cadene-resnext64x4 | 16 | 814.37 | 809.71 | 0.58% | |
| slim-mobilenet | 64 | 6,690.65 | 6,654.55 | 0.54% | |
| slim-nasnetalarge | 64 | 197.39 | 203.03 | -2.78% | |
| slim-resnet50v2 | 64 | 3,453.13 | 3,434.71 | 0.54% | |
| bert-mrpc-onnx | 8 | 1,150.68 | 1,142.05 | 0.76% | |
| bert-mrpc-tf | 1 | 475.59 | 464.19 | 2.46% | |
| pytorch-examples-wlang-gru | 1 | 476.32 | 476.27 | 0.01% | |
| pytorch-examples-wlang-lstm | 1 | 437.02 | 442.88 | -1.32% | |
| torchvision-resnet50_1 | 1 | 808.35 | 813.23 | -0.60% | |
| cadene-dpn92_1 | 1 | 423.61 | 421.24 | 0.56% | |
| cadene-resnext101_1 | 1 | 393.47 | 392.62 | 0.22% | |
| onnx-taau-downsample | 1 | 397.19 | 395.87 | 0.33% | |
| dlrm-criteoterabyte | 1 | 31.95 | 31.80 | 0.46% | |
| dlrm-criteoterabyte_fp16 | 1 | 51.03 | 50.96 | 0.13% | |
| agentmodel | 1 | 9,042.00 | 9,458.32 | -4.40% | 🔴 |
| unet_fp16 | 2 | 58.61 | 58.33 | 0.48% | |
| resnet50v1_fp16 | 1 | 1,050.42 | 1,071.23 | -1.94% | |
| resnet50v1_int8 | 1 | 857.26 | 893.66 | -4.07% | 🔴 |
| bert_base_cased_fp16 | 64 | 1,171.05 | 1,162.33 | 0.75% | |
| bert_large_uncased_fp16 | 32 | 363.50 | 353.92 | 2.71% | |
| bert_large_fp16 | 1 | 201.73 | 194.83 | 3.55% | 🔆 |
| distilgpt2_fp16 | 16 | 2,230.94 | 2,215.11 | 0.71% | |
| yolov5s | 1 | 515.35 | 543.55 | -5.19% | 🔴 |
| tinyllama | 1 | 43.85 | 43.59 | 0.60% | |
| vicuna-fastchat | 1 | 44.06 | 44.05 | 0.02% | |
| whisper-tiny-encoder | 1 | 412.78 | 411.57 | 0.29% | |
| whisper-tiny-decoder | 1 | 411.53 | 411.31 | 0.05% | |
| llama2_7b | 1 | nan | nan | nan% | |
| qwen1.5-7b | 1 | 23.56 | 23.41 | 0.62% | |
| phi3-3.8b | 1 | nan | nan | nan% | |
| mask-rcnn | 1 | 21.04 | 18.55 | 13.44% | 🔆 |
| llama3-8b | 1 | 21.73 | 21.65 | 0.37% | |
| whisper-large-encoder | 1 | 10.22 | 10.17 | 0.49% | |
| whisper-large-decoder | 1 | 98.18 | 97.78 | 0.41% | |
| mistral-7b | 1 | 23.76 | 23.63 | 0.51% | |
| FLUX.1-schnell | 1 | 893.26 | 904.60 | -1.25% | |
| nan | nan | nan | nan | nan% | |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

❌ llama2_7b: ERROR - check error output
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/llama2_7b/decoder_model.onnx


❌ qwen1.5-7b: ERROR - check error output
usage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask position_ids 1 256 @attention_mask 1 256 @position_ids 1 256


❌ phi3-3.8b: ERROR - check error output
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/phi3-3.8b/model.onnx


🔴 mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ llama3-8b: PASSED: MIGraphX meets tolerance

❌ whisper-large-encoder: ERROR - check error output
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/convolution.hpp:100: normalize_compute_shape: CONVOLUTION: mismatched channel numbers


     ✅ whisper-large-decoder: PASSED: MIGraphX meets tolerance

     ✅ mistral-7b: PASSED: MIGraphX meets tolerance

     ✅ FLUX.1-schnell: PASSED: MIGraphX meets tolerance

6 participants