From ce00b89a35e3323b6a3b0185f5bb3d73592156df Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 14 Oct 2025 21:30:31 -0700 Subject: [PATCH 01/28] fix localized sort --- src/module.cpp | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 0b94a36c5a0..0188d68f633 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1614,15 +1614,34 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) { if(fusion_ins.count(it) == 0) { - auto next = std::next(it); // move_instruction updates the iterator - this->move_instruction(it, start_ins); - it = next; + // only move if none of its inputs are after + bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), + [&](instruction_ref input) { + if(!has_instruction(input)) + return false; + auto input_pos = std::distance(this->begin(), input); + auto start_pos = std::distance(this->begin(), start_ins); + auto curr_pos = std::distance(this->begin(), it); + return input_pos > start_pos && input_pos < curr_pos; + }); + + if(!has_input_in_range) + { + auto next = std::next(it); + this->move_instruction(it, start_ins); + it = next; + } + else + { + ++it; + } } else { ++it; } } + assert(this->validate() == this->end()); } bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); } From 92aead88c4062e74858eb716f6f1bfc05eb2547d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 14 Oct 2025 22:01:41 -0700 Subject: [PATCH 02/28] unsupported fusions and clang format --- src/module.cpp | 8 +++---- src/targets/gpu/fuse_mlir.cpp | 36 ++++++++++++++++++++++++------- test/verify/CMakeLists.txt | 4 ---- test/verify/test_conv_add_dot.cpp | 1 + test/verify/test_dot_add_dot.cpp | 1 + 5 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 0188d68f633..b17cf985b0e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1615,14 +1615,12 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) if(fusion_ins.count(it) == 0) { // only move if none of its inputs are after - bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), + bool has_input_in_range = + std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { if(!has_instruction(input)) return false; - auto input_pos = std::distance(this->begin(), input); - auto start_pos = std::distance(this->begin(), start_ins); - auto curr_pos = std::distance(this->begin(), it); - return input_pos > start_pos && input_pos < curr_pos; + return std::distance(start_ins, input) > 0 && std::distance(input, it) > 0; }); if(!has_input_in_range) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 83d30c50bba..b5261203482 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -788,6 +788,22 @@ struct find_mlir_fused_geg_ops { mlir_mode conv_mode = mlir_mode::none; mlir_mode dot_mode = mlir_mode::none; + std::string gfx_name; + + // check if individual GEMM meets architecture requirements + bool is_gemm_supported(instruction_ref ins) const + { + // on navi3x (gfx110x), wmma doesn't support fp32, so skip fp32 GEMMs + // one gemm being f32 is sufficient to turn off this fusion + if(starts_with(gfx_name, "gfx110")) + { + if(ins->get_shape().type() == shape::type_t::float_type) + { + return false; + } + } + return true; + } /* * Matches: @@ -797,11 +813,15 @@ struct find_mlir_fused_geg_ops */ auto matcher() const { - auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)) - .bind("first_gemm_based_op"); + auto gemm_supported = match::make_basic_pred_matcher( + [this](instruction_ref ins) { return is_gemm_supported(ins); }); + + auto first_dot_or_conv = + match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))(gemm_supported) + .bind("first_gemm_based_op"); auto elemwise = mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise"); - return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise)) + return is_mlir_dot(dot_mode)(gemm_supported)(match::any_of[match::inputs()](elemwise)) .bind("second_gemm_op"); } @@ -820,10 +840,11 @@ struct find_mlir_fused_geg_ops })) return; - // only one input to second_gemm should depend on elemwise + // only one input to second_gemm should depend on elemwise or first_gemm auto second_gemm_inputs = second_gemm_ins->inputs(); if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) { - return i != elemwise_ins and reaches(elemwise_ins, i); + return (i != elemwise_ins and reaches(elemwise_ins, i)) or + (i != first_gemm_ins and reaches(first_gemm_ins, i)); })) return; @@ -890,13 +911,11 @@ struct find_mlir_fused_geg_ops // so that we can safely place the fused mod in the multi-out case at the beginning of the // chain mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins); - auto fused_ins = mpm.get_module().insert_instruction(first_gemm_ins, mlir_op{second_gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); - if(first_gemm_has_multi_outs or elemwise_has_multi_outs) { std::size_t output_idx = 0; @@ -1452,7 +1471,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches( mpm, find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), - .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + .dot_mode = get_mode("fused_dot", mlir_mode::fast), + .gfx_name = device_name}); mpm.run_pass(dead_code_elimination{}); } diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt index 56b0a2cf2ba..daa5a22138e 100644 --- a/test/verify/CMakeLists.txt +++ b/test/verify/CMakeLists.txt @@ -42,7 +42,3 @@ foreach(SECTION general reduce rnn conv gemm) ) endif() endforeach() - -# TODO: remove this when MIGraphX disables attn-like offloading to rocMLIR -# f32 not supported on navi3x -set_tests_properties(test_verify_rnn PROPERTIES ENVIRONMENT "MIGRAPHX_ENABLE_MLIR_GEG_FUSION=0") diff --git a/test/verify/test_conv_add_dot.cpp b/test/verify/test_conv_add_dot.cpp index a4ba2018f7f..af0fd090e69 100644 --- a/test/verify/test_conv_add_dot.cpp +++ b/test/verify/test_conv_add_dot.cpp @@ -63,3 +63,4 @@ struct test_conv_add_dot : verify_program> template struct test_conv_add_dot; template struct test_conv_add_dot; +template struct test_conv_add_dot; diff --git a/test/verify/test_dot_add_dot.cpp b/test/verify/test_dot_add_dot.cpp index a087e60e018..ff56a982de0 100644 --- a/test/verify/test_dot_add_dot.cpp +++ b/test/verify/test_dot_add_dot.cpp @@ -49,3 +49,4 @@ struct test_dot_add_dot : verify_program> template struct test_dot_add_dot; template struct test_dot_add_dot; +template struct test_dot_add_dot; From 55e08683702195d488e0fdd6dd4ba05c61b80e52 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 14 Oct 2025 22:03:58 -0700 Subject: [PATCH 03/28] comment --- src/module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index b17cf985b0e..3d7b37d11be 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1614,7 +1614,7 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) { if(fusion_ins.count(it) == 0) { - // only move if none of its inputs are after + // only move if none of its inputs are after start_ins bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { From cd069bcde5f7ce2a2136a4e70ce2da78eaebe20a Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 14 Oct 2025 22:23:59 -0700 Subject: [PATCH 04/28] format --- src/module.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 3d7b37d11be..ad2cfc69c5f 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1615,16 +1615,16 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) if(fusion_ins.count(it) == 0) { // only move if none of its inputs are after start_ins - bool has_input_in_range = - std::any_of(it->inputs().begin(), it->inputs().end(), - [&](instruction_ref input) { - if(!has_instruction(input)) + bool has_input_in_range = + std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { + if(!has_instruction(input)) return false; return std::distance(start_ins, input) > 0 && std::distance(input, it) > 0; }); - - if(!has_input_in_range) - { + + + if(!has_input_in_range) + { auto next = std::next(it); this->move_instruction(it, start_ins); it = next; From 93e96073ae9d5f00e04c14f9c8fe49edbba20050 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 14 Oct 2025 22:41:29 -0700 Subject: [PATCH 05/28] format --- src/module.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index ad2cfc69c5f..80a74433eac 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1617,14 +1617,13 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) // only move if none of its inputs are after start_ins bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { - if(!has_instruction(input)) + if(!has_instruction(input)) return false; return std::distance(start_ins, input) > 0 && std::distance(input, it) > 0; }); - - if(!has_input_in_range) - { + if(!has_input_in_range) + { auto next = std::next(it); this->move_instruction(it, start_ins); it = next; From bac1bcae1df143ab7c842ad0d7ab0423c9f60f6d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 11:27:02 -0700 Subject: [PATCH 06/28] ci failures --- src/module.cpp | 4 +- src/targets/gpu/fuse_mlir.cpp | 2 +- test/verify/test_dot_add_transpose_dot.cpp | 56 ++++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 test/verify/test_dot_add_transpose_dot.cpp diff --git a/src/module.cpp b/src/module.cpp index 80a74433eac..a25e257b01d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1614,12 +1614,12 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) { if(fusion_ins.count(it) == 0) { - // only move if none of its inputs are after start_ins bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { if(!has_instruction(input)) return false; - return std::distance(start_ins, input) > 0 && std::distance(input, it) > 0; + // verify: start_ins < input < it + return std::find(std::next(start_ins), it, input) != it; }); if(!has_input_in_range) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index b5261203482..05295e9fa85 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -843,7 +843,7 @@ struct find_mlir_fused_geg_ops // only one input to second_gemm should depend on elemwise or first_gemm auto second_gemm_inputs = second_gemm_ins->inputs(); if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) { - return (i != elemwise_ins and reaches(elemwise_ins, i)) or + return (i != elemwise_ins and reaches(elemwise_ins, i)) and (i != first_gemm_ins and reaches(first_gemm_ins, i)); })) return; diff --git a/test/verify/test_dot_add_transpose_dot.cpp b/test/verify/test_dot_add_transpose_dot.cpp new file mode 100644 index 00000000000..7c3738c6297 --- /dev/null +++ b/test/verify/test_dot_add_transpose_dot.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_dot_add_transpose_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{DType, {512, 256}}; + migraphx::shape s2{DType, {256, 32}}; + migraphx::shape s3{DType, {512, 32}}; + auto a = mm->add_parameter("a", s1); // 512, 256 + auto b = mm->add_parameter("b", s2); // 256, 32 + auto c = mm->add_parameter("c", s3); // 512, 32 + auto d = mm->add_parameter("d", s3); // 512, 32 + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // (512, 256) x (256, 32) = (512, 32) + auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); // (512, 32) + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); // (32, 512) + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, transpose); // (512, 32) x (32, 512) + mm->add_return({dot2}); + return p; + } +}; + +template struct test_dot_add_transpose_dot; +template struct test_dot_add_transpose_dot; +template struct test_dot_add_transpose_dot; From 37c31cd91516745079ab8236c800e020ff055861 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 11:35:35 -0700 Subject: [PATCH 07/28] change geg fuse_mlir tests to be f16 to avoid needing to worry about gpu for these specific tests --- test/gpu/fuse_mlir.cpp | 48 +++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 5532da8832b..d81af2f77b3 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1776,10 +1776,10 @@ TEST_CASE(unpack_fp4_dot_odd) TEST_CASE(dot_add_dot) { - migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; - migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; - migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::half_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::half_type, {4, 2}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -1817,7 +1817,7 @@ TEST_CASE(dot_add_dot) TEST_CASE(dot_add_dot_square) { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -1854,9 +1854,9 @@ TEST_CASE(dot_add_dot_square) TEST_CASE(dot_mul_dot) { - migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; - migraphx::shape s3{migraphx::shape::float_type, {4, 5}}; + migraphx::shape s1{migraphx::shape::half_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::half_type, {4, 5}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -1931,10 +1931,10 @@ TEST_CASE(conv_add) TEST_CASE(conv_add_dot) { - migraphx::shape is{migraphx::shape::float_type, {2, 4, 8, 8}}; - migraphx::shape ys{migraphx::shape::float_type, {2, 8, 8, 8}}; - migraphx::shape ws{migraphx::shape::float_type, {8, 4, 1, 1}}; - migraphx::shape zs{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::shape is{migraphx::shape::half_type, {2, 4, 8, 8}}; + migraphx::shape ys{migraphx::shape::half_type, {2, 8, 8, 8}}; + migraphx::shape ws{migraphx::shape::half_type, {8, 4, 1, 1}}; + migraphx::shape zs{migraphx::shape::half_type, {2, 8, 8, 4}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2020,9 +2020,9 @@ TEST_CASE(dot_multi_user_add) TEST_CASE(dot_add_multi_user_dot) // GEG fusion has two outputs, E has external user { - migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {3, 5}}; - migraphx::shape s3{migraphx::shape::float_type, {5, 2}}; + migraphx::shape s1{migraphx::shape::half_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 5}}; + migraphx::shape s3{migraphx::shape::half_type, {5, 2}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2069,7 +2069,7 @@ TEST_CASE(dot_add_multi_user_dot) TEST_CASE(dot_add_multi_user_dot_with_transpose) // GEG fusion has two outputs, E has external user { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2120,7 +2120,7 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) TEST_CASE(dot_add_multi_user_dot_two_externals) // GEG fusion has two outputs, E has external user { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2174,7 +2174,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) // of will-be-fused ops // This also shows the relu being fused, since it is a unary op { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2231,7 +2231,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) // This also shows the relu being fused, since it is a unary op. // Result should be, and is, equivalent to the previous test { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2286,7 +2286,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) // longer chain of logic, for both cases of input fusion. When enabled, // the mul gets fused into the GEG fusion. { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2378,7 +2378,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) // Testing inputs being defined within the span of will-be-fused ops, including // longer chain of logic { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2467,7 +2467,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) TEST_CASE(dot_pw_multi_user_dot) // GEG fusion has two outputs, E has external user, E is multiple elemwise ops { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2522,7 +2522,7 @@ TEST_CASE(dot_pw_multi_user_dot) TEST_CASE(dot_multi_user_add_dot) // GEG fusion has two outputs (first G has external user) { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -2569,7 +2569,7 @@ TEST_CASE(dot_multi_user_add_dot) TEST_CASE(dot_add_dot_both_multi_user) // GEG fusion has three outputs (first G has external user, E has external user) { - migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; { auto* mm = p1.get_main_module(); From 89fb484a31b469c83ae9fa3298ddb13e0dbdfd07 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 11:37:56 -0700 Subject: [PATCH 08/28] comment --- src/module.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/module.cpp b/src/module.cpp index a25e257b01d..965abea5adf 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1614,6 +1614,7 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) { if(fusion_ins.count(it) == 0) { + // only move if none of its inputs are after start_ins bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { if(!has_instruction(input)) From b3dd3bf2e030364481b5b304493393b2312f37c9 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 11:42:33 -0700 Subject: [PATCH 09/28] style issues --- src/module.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 965abea5adf..7b00c596030 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1617,21 +1617,22 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) // only move if none of its inputs are after start_ins bool has_input_in_range = std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) { - if(!has_instruction(input)) + if(not has_instruction(input)) return false; // verify: start_ins < input < it return std::find(std::next(start_ins), it, input) != it; }); - if(!has_input_in_range) + if(has_input_in_range) { - auto next = std::next(it); - this->move_instruction(it, start_ins); - it = next; + // input is after start_ins, meaning can't move this instruction + ++it; } else { - ++it; + auto next = std::next(it); + this->move_instruction(it, start_ins); + it = next; } } else From bcf4173510a389b78fcf13b541e029c0cd6f2a25 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 12:11:52 -0700 Subject: [PATCH 10/28] format --- test/verify/test_dot_add_transpose_dot.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/verify/test_dot_add_transpose_dot.cpp b/test/verify/test_dot_add_transpose_dot.cpp index 7c3738c6297..50b0ceaecb8 100644 --- a/test/verify/test_dot_add_transpose_dot.cpp +++ b/test/verify/test_dot_add_transpose_dot.cpp @@ -41,11 +41,13 @@ struct test_dot_add_transpose_dot : verify_programadd_parameter("b", s2); // 256, 32 auto c = mm->add_parameter("c", s3); // 512, 32 auto d = mm->add_parameter("d", s3); // 512, 32 - auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // (512, 256) x (256, 32) = (512, 32) - auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); // (512, 32) + auto dot1 = mm->add_instruction( + migraphx::make_op("dot"), a, b); // (512, 256) x (256, 32) = (512, 32) + auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); // (512, 32) auto transpose = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); // (32, 512) - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, transpose); // (512, 32) x (32, 512) + auto dot2 = + mm->add_instruction(migraphx::make_op("dot"), add, transpose); // (512, 32) x (32, 512) mm->add_return({dot2}); return p; } From 927ec55aed18a9f1ed000967993825de941c953d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 15 Oct 2025 12:26:27 -0700 Subject: [PATCH 11/28] format again --- test/verify/test_dot_add_transpose_dot.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/verify/test_dot_add_transpose_dot.cpp b/test/verify/test_dot_add_transpose_dot.cpp index 50b0ceaecb8..2ec723df5cc 100644 --- a/test/verify/test_dot_add_transpose_dot.cpp +++ b/test/verify/test_dot_add_transpose_dot.cpp @@ -44,9 +44,9 @@ struct test_dot_add_transpose_dot : verify_programadd_instruction( migraphx::make_op("dot"), a, b); // (512, 256) x (256, 32) = (512, 32) auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); // (512, 32) - auto transpose = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); // (32, 512) - auto dot2 = + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); // (32, 512) + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, transpose); // (512, 32) x (32, 512) mm->add_return({dot2}); return p; From 92b4435a294f414f57abf33a83616d19f0c6f27a Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 20 Oct 2025 21:35:55 -0700 Subject: [PATCH 12/28] f32 disabled for all navis --- src/targets/gpu/fuse_mlir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 05295e9fa85..f0747140f95 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -793,9 +793,9 @@ struct find_mlir_fused_geg_ops // check if individual GEMM meets architecture requirements bool is_gemm_supported(instruction_ref ins) const { - // on navi3x (gfx110x), wmma doesn't support fp32, so skip fp32 GEMMs + // on navis, wmma doesn't support fp32, so skip fp32 GEMMs // one gemm being f32 is sufficient to turn off this fusion - if(starts_with(gfx_name, "gfx110")) + if(starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12")) { if(ins->get_shape().type() == shape::type_t::float_type) { From cbc483ea3ab3dc34641d7537aab5adb151ea091e Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 20 Oct 2025 22:09:13 -0700 Subject: [PATCH 13/28] force ordering of gemm1 inputs in geg --- src/targets/gpu/fuse_mlir.cpp | 2 +- test/gpu/fuse_mlir.cpp | 56 +++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f0747140f95..2408566d17b 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -821,7 +821,7 @@ struct find_mlir_fused_geg_ops .bind("first_gemm_based_op"); auto elemwise = mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise"); - return is_mlir_dot(dot_mode)(gemm_supported)(match::any_of[match::inputs()](elemwise)) + return is_mlir_dot(dot_mode)(gemm_supported)(match::arg(0)(elemwise)) .bind("second_gemm_op"); } diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index d81af2f77b3..3e18261e4c7 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1815,6 +1815,62 @@ TEST_CASE(dot_add_dot) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dot_add_dot_ABC) +// MLIR currently only supports (A*B)*C GEG patterns +{ + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::half_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::half_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} + mm->add_return({dot2}); + } + run_pass(p1); + // ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure it happens. + std::stringstream ss; + ss << p1; + std::string program_str = ss.str(); + EXPECT(program_str.find("geg") != std::string::npos); +} + +TEST_CASE(dot_add_dot_CAB) +// MLIR currently does not support C*(A*B) GEG patterns +{ + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::half_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::half_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), y, add); // {4, 2}*{2, 4} = {4, 4} + mm->add_return({dot2}); + } + run_pass(p1); + // ensure "geg" is not in the fused program + std::stringstream ss; + ss << p1; + std::string program_str = ss.str(); + EXPECT(program_str.find("geg") == std::string::npos); +} + TEST_CASE(dot_add_dot_square) { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; From 5cf6f93444db9862d7f449228d5857e663b6f1f0 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 20 Oct 2025 23:20:32 -0700 Subject: [PATCH 14/28] toggle off multi out support for intermediates in g+g fusion --- src/targets/gpu/fuse_mlir.cpp | 11 +++++++ .../gpu/include/migraphx/gpu/fuse_mlir.hpp | 1 + test/gpu/fuse_mlir.cpp | 30 ++++++++++++------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 2408566d17b..642d566e026 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -143,6 +143,13 @@ bool mlir_attention_enabled(context* ctx) #endif } +// for g+g: check if multi-users for intermediates is supported +bool mlir_geg_multi_user_intermediates_supported() +{ + // return no for now. TODO: have a toggle + return false; +} + #ifdef MIGRAPHX_MLIR struct mlir_op @@ -858,6 +865,10 @@ struct find_mlir_fused_geg_ops bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1; bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1; + // if we have multi outs for either of the intermediates, check if this is supported first + if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and not mlir_geg_multi_user_intermediates_supported()) + return; + // add the first gemm to the module std::vector first_gemm_mapped_inputs; first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size()); diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 7f694f4432b..78740e1e01d 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -35,6 +35,7 @@ namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(context* ctx); +MIGRAPHX_GPU_EXPORT bool mlir_geg_multi_user_intermediates_supported(); struct MIGRAPHX_GPU_EXPORT fuse_mlir { diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 3e18261e4c7..e9b1040ad88 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -2075,6 +2075,7 @@ TEST_CASE(dot_multi_user_add) TEST_CASE(dot_add_multi_user_dot) // GEG fusion has two outputs, E has external user +// not currently supported in rocMLIR { migraphx::shape s1{migraphx::shape::half_type, {3, 3}}; migraphx::shape s2{migraphx::shape::half_type, {3, 5}}; @@ -2117,13 +2118,14 @@ TEST_CASE(dot_add_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } TEST_CASE(dot_add_multi_user_dot_with_transpose) // GEG fusion has two outputs, E has external user +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2168,13 +2170,14 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } TEST_CASE(dot_add_multi_user_dot_two_externals) // GEG fusion has two outputs, E has external user +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2219,7 +2222,7 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t1, external_t2}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2229,6 +2232,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) // Base case for testing inputs being defined within the span // of will-be-fused ops // This also shows the relu being fused, since it is a unary op +// currently not supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2276,7 +2280,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2286,6 +2290,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) // Testing inputs being defined within the span of will-be-fused ops // This also shows the relu being fused, since it is a unary op. // Result should be, and is, equivalent to the previous test +// currently not supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2331,7 +2336,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2341,6 +2346,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) // Base case for inputs being defined within the span of will-be-fused ops, including // longer chain of logic, for both cases of input fusion. When enabled, // the mul gets fused into the GEG fusion. +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2421,7 +2427,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2433,6 +2439,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) // GEG fusion has two outputs, E has external user // Testing inputs being defined within the span of will-be-fused ops, including // longer chain of logic +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2512,7 +2519,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2522,6 +2529,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) TEST_CASE(dot_pw_multi_user_dot) // GEG fusion has two outputs, E has external user, E is multiple elemwise ops +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2570,13 +2578,14 @@ TEST_CASE(dot_pw_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_mul, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } TEST_CASE(dot_multi_user_add_dot) // GEG fusion has two outputs (first G has external user) +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2617,13 +2626,14 @@ TEST_CASE(dot_multi_user_add_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } TEST_CASE(dot_add_dot_both_multi_user) // GEG fusion has three outputs (first G has external user, E has external user) +// not currently supported in rocMLIR { migraphx::shape s{migraphx::shape::half_type, {1, 3, 3}}; migraphx::program p1; @@ -2666,7 +2676,7 @@ TEST_CASE(dot_add_dot_both_multi_user) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_elemwise, get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } From 23585afc4989a2078920aeda708f7cb9d493cafb Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 21 Oct 2025 16:20:35 -0700 Subject: [PATCH 15/28] format --- src/targets/gpu/fuse_mlir.cpp | 3 ++- test/gpu/fuse_mlir.cpp | 37 +++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 642d566e026..459f574527d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -866,7 +866,8 @@ struct find_mlir_fused_geg_ops bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1; // if we have multi outs for either of the intermediates, check if this is supported first - if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and not mlir_geg_multi_user_intermediates_supported()) + if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and + not mlir_geg_multi_user_intermediates_supported()) return; // add the first gemm to the module diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index e9b1040ad88..19c4155e529 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1832,11 +1832,12 @@ TEST_CASE(dot_add_dot_ABC) auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} auto add = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4} - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} mm->add_return({dot2}); } run_pass(p1); - // ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure it happens. + // ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure + // it happens. std::stringstream ss; ss << p1; std::string program_str = ss.str(); @@ -1860,7 +1861,7 @@ TEST_CASE(dot_add_dot_CAB) auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} auto add = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4} - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), y, add); // {4, 2}*{2, 4} = {4, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), y, add); // {4, 2}*{2, 4} = {4, 4} mm->add_return({dot2}); } run_pass(p1); @@ -2118,7 +2119,8 @@ TEST_CASE(dot_add_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2170,7 +2172,8 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2222,7 +2225,8 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t1, external_t2}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2280,7 +2284,8 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2336,7 +2341,8 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2427,7 +2433,8 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2519,7 +2526,8 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2578,7 +2586,8 @@ TEST_CASE(dot_pw_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_mul, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2626,7 +2635,8 @@ TEST_CASE(dot_multi_user_add_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } @@ -2676,7 +2686,8 @@ TEST_CASE(dot_add_dot_both_multi_user) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_elemwise, get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); } From cc75a89bc915a54e288844ceefe55b72ec481cee Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 00:58:11 -0700 Subject: [PATCH 16/28] format --- test/gpu/fuse_mlir.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 19c4155e529..3576b68ead7 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -2119,7 +2119,7 @@ TEST_CASE(dot_add_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2172,7 +2172,7 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2225,7 +2225,7 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t1, external_t2}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2284,7 +2284,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2341,7 +2341,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2433,7 +2433,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) @@ -2526,7 +2526,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) @@ -2586,7 +2586,7 @@ TEST_CASE(dot_pw_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_mul, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2635,7 +2635,7 @@ TEST_CASE(dot_multi_user_add_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); @@ -2686,7 +2686,7 @@ TEST_CASE(dot_add_dot_both_multi_user) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_elemwise, get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) return; EXPECT(p1.sort() == p2.sort()); From 8ea8264e4c26f499e225be568cd263cee89083fc Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 01:31:54 -0700 Subject: [PATCH 17/28] make flag in pass class for toggling multi out intermediates support --- src/targets/gpu/fuse_mlir.cpp | 14 ++--- .../gpu/include/migraphx/gpu/fuse_mlir.hpp | 1 + test/gpu/fuse_mlir.cpp | 59 +++++++++---------- 3 files changed, 33 insertions(+), 41 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 459f574527d..a750fc3c5ab 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -143,13 +143,6 @@ bool mlir_attention_enabled(context* ctx) #endif } -// for g+g: check if multi-users for intermediates is supported -bool mlir_geg_multi_user_intermediates_supported() -{ - // return no for now. TODO: have a toggle - return false; -} - #ifdef MIGRAPHX_MLIR struct mlir_op @@ -796,6 +789,7 @@ struct find_mlir_fused_geg_ops mlir_mode conv_mode = mlir_mode::none; mlir_mode dot_mode = mlir_mode::none; std::string gfx_name; + bool enable_geg_multi_out_intermediates = false; // check if individual GEMM meets architecture requirements bool is_gemm_supported(instruction_ref ins) const @@ -867,7 +861,7 @@ struct find_mlir_fused_geg_ops // if we have multi outs for either of the intermediates, check if this is supported first if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and - not mlir_geg_multi_user_intermediates_supported()) + not enable_geg_multi_out_intermediates) return; // add the first gemm to the module @@ -1484,7 +1478,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const mpm, find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), .dot_mode = get_mode("fused_dot", mlir_mode::fast), - .gfx_name = device_name}); + .gfx_name = device_name, + .enable_geg_multi_out_intermediates = enable_geg_multi_out_intermediates + }); mpm.run_pass(dead_code_elimination{}); } diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 78740e1e01d..82edc35e581 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -41,6 +41,7 @@ struct MIGRAPHX_GPU_EXPORT fuse_mlir { context* ctx = nullptr; bool enable_extra = false; + bool enable_geg_multi_out_intermediates = false; std::string name() const { return "gpu::fuse_mlir"; } void apply(module_pass_manager& mpm) const; }; diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 3576b68ead7..4b5694acf33 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -52,10 +52,15 @@ struct non_mlir_op } }; -static void run_pass(migraphx::program& p) +static void run_pass(migraphx::program& p, bool enable_geg_multi_user = false) { migraphx::run_passes( - p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}}); + p, + {migraphx::gpu::fuse_mlir{ + .enable_extra = true, + .enable_geg_multi_out_intermediates = enable_geg_multi_user + }, + migraphx::dead_code_elimination{}}); } template @@ -2095,7 +2100,7 @@ TEST_CASE(dot_add_multi_user_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2119,8 +2124,7 @@ TEST_CASE(dot_add_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2146,7 +2150,7 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2172,8 +2176,7 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2199,7 +2202,7 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, external_t1, external_t2}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2225,8 +2228,7 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t1, external_t2}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2256,7 +2258,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2284,8 +2286,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2313,7 +2314,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2341,8 +2342,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2374,7 +2374,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2433,8 +2433,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2467,7 +2466,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2526,8 +2525,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_add, external_t}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) EXPECT(p1.sort() == p3.sort()); @@ -2560,7 +2558,7 @@ TEST_CASE(dot_pw_multi_user_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({elemwise, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2586,8 +2584,7 @@ TEST_CASE(dot_pw_multi_user_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); mm->add_return({get_mul, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2611,7 +2608,7 @@ TEST_CASE(dot_multi_user_add_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); mm->add_return({dot2, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2635,8 +2632,7 @@ TEST_CASE(dot_multi_user_add_dot) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } @@ -2660,7 +2656,7 @@ TEST_CASE(dot_add_dot_both_multi_user) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); mm->add_return({add, dot2, transpose}); } - run_pass(p1); + run_pass(p1, true); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2686,8 +2682,7 @@ TEST_CASE(dot_add_dot_both_multi_user) migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); mm->add_return({get_elemwise, get_dot2, transpose}); } - if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}) or - not migraphx::gpu::mlir_geg_multi_user_intermediates_supported()) + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) return; EXPECT(p1.sort() == p2.sort()); } From e6d7b47e2a50c033b3775746d59fb460e4ebc0db Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 03:28:38 -0700 Subject: [PATCH 18/28] format/tidy --- src/targets/gpu/fuse_mlir.cpp | 4 ++-- test/gpu/fuse_mlir.cpp | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a750fc3c5ab..3583b33ee5b 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1479,8 +1479,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), .dot_mode = get_mode("fused_dot", mlir_mode::fast), .gfx_name = device_name, - .enable_geg_multi_out_intermediates = enable_geg_multi_out_intermediates - }); + .enable_geg_multi_out_intermediates = + enable_geg_multi_out_intermediates}); mpm.run_pass(dead_code_elimination{}); } diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 4b5694acf33..f7dc0de5a2b 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -56,10 +56,8 @@ static void run_pass(migraphx::program& p, bool enable_geg_multi_user = false) { migraphx::run_passes( p, - {migraphx::gpu::fuse_mlir{ - .enable_extra = true, - .enable_geg_multi_out_intermediates = enable_geg_multi_user - }, + {migraphx::gpu::fuse_mlir{.enable_extra = true, + .enable_geg_multi_out_intermediates = enable_geg_multi_user}, migraphx::dead_code_elimination{}}); } @@ -1820,7 +1818,7 @@ TEST_CASE(dot_add_dot) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(dot_add_dot_ABC) +TEST_CASE(dot_add_dot_abc) // MLIR currently only supports (A*B)*C GEG patterns { migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; @@ -1849,7 +1847,7 @@ TEST_CASE(dot_add_dot_ABC) EXPECT(program_str.find("geg") != std::string::npos); } -TEST_CASE(dot_add_dot_CAB) +TEST_CASE(dot_add_dot_cab) // MLIR currently does not support C*(A*B) GEG patterns { migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; From dcf75fe7d97d8b54c1cfc21738a99eba3c9245fd Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 08:16:21 -0700 Subject: [PATCH 19/28] edit test ensuring geg fusion happens on correctly oriented gemms doesnt assert when gfx is navi --- test/gpu/fuse_mlir.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index f7dc0de5a2b..f01cb4c1eda 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1844,6 +1844,12 @@ TEST_CASE(dot_add_dot_abc) std::stringstream ss; ss << p1; std::string program_str = ss.str(); + + // regardless if the matmul is correctly oriented, geg should not happen on navi + auto device_name = migraphx::gpu::get_device_name(); + bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); + if(is_navi) + EXPECT(program_str.find("geg") == std::string::npos); EXPECT(program_str.find("geg") != std::string::npos); } From 461bb73bdcb56d748c485a7747bfc5ec92c5b076 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 08:18:06 -0700 Subject: [PATCH 20/28] format --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index f01cb4c1eda..eafcafe90b6 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -58,7 +58,7 @@ static void run_pass(migraphx::program& p, bool enable_geg_multi_user = false) p, {migraphx::gpu::fuse_mlir{.enable_extra = true, .enable_geg_multi_out_intermediates = enable_geg_multi_user}, - migraphx::dead_code_elimination{}}); + migraphx::dead_code_elimination{}}); } template From 2346b17158a8d0b467e403f258603ae94ac055aa Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 08:38:38 -0700 Subject: [PATCH 21/28] format --- test/gpu/fuse_mlir.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index eafcafe90b6..a0241b9bb8a 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1847,7 +1847,8 @@ TEST_CASE(dot_add_dot_abc) // regardless if the matmul is correctly oriented, geg should not happen on navi auto device_name = migraphx::gpu::get_device_name(); - bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); + bool is_navi = + migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); if(is_navi) EXPECT(program_str.find("geg") == std::string::npos); EXPECT(program_str.find("geg") != std::string::npos); From ef3aa000408fb9fd40fbe5d632077a7ffb63e9f6 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 22 Oct 2025 08:51:03 -0700 Subject: [PATCH 22/28] format --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index a0241b9bb8a..e50dea7d9a5 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1847,7 +1847,7 @@ TEST_CASE(dot_add_dot_abc) // regardless if the matmul is correctly oriented, geg should not happen on navi auto device_name = migraphx::gpu::get_device_name(); - bool is_navi = + bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); if(is_navi) EXPECT(program_str.find("geg") == std::string::npos); From d8df55340a44e37bd055522d85e02bf2fa8015a8 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 24 Oct 2025 00:16:37 +0000 Subject: [PATCH 23/28] fix bad conditional --- test/gpu/fuse_mlir.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index e50dea7d9a5..c46e462b99a 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1849,9 +1849,12 @@ TEST_CASE(dot_add_dot_abc) auto device_name = migraphx::gpu::get_device_name(); bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); - if(is_navi) + if(is_navi){ EXPECT(program_str.find("geg") == std::string::npos); - EXPECT(program_str.find("geg") != std::string::npos); + } + else { + EXPECT(program_str.find("geg") != std::string::npos); + } } TEST_CASE(dot_add_dot_cab) From 376747c80ece4383b8243c65a4b8596bcc0fc128 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 24 Oct 2025 07:20:00 +0000 Subject: [PATCH 24/28] ignore test if flag is not enabled --- test/gpu/fuse_mlir.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index c46e462b99a..1d95ad4bb9e 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -1849,12 +1849,11 @@ TEST_CASE(dot_add_dot_abc) auto device_name = migraphx::gpu::get_device_name(); bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); - if(is_navi){ + // fusion should not happen if the device is navi or the fusion flag is disabled + if(is_navi or not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) EXPECT(program_str.find("geg") == std::string::npos); - } - else { + else EXPECT(program_str.find("geg") != std::string::npos); - } } TEST_CASE(dot_add_dot_cab) From 15b33e67c92e9e39095f1797e81895ee26d52249 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 29 Oct 2025 13:59:11 +0000 Subject: [PATCH 25/28] use initializer instead of bool parameter for enabling multi out g+g --- test/gpu/fuse_mlir.cpp | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 1d95ad4bb9e..1758b929da0 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -52,13 +53,10 @@ struct non_mlir_op } }; -static void run_pass(migraphx::program& p, bool enable_geg_multi_user = false) +static void run_pass(migraphx::program& p, migraphx::gpu::fuse_mlir fm = {}) { - migraphx::run_passes( - p, - {migraphx::gpu::fuse_mlir{.enable_extra = true, - .enable_geg_multi_out_intermediates = enable_geg_multi_user}, - migraphx::dead_code_elimination{}}); + fm.enable_extra = true; + migraphx::run_passes(p, {fm, migraphx::dead_code_elimination{}}); } template @@ -1838,6 +1836,7 @@ TEST_CASE(dot_add_dot_abc) auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} mm->add_return({dot2}); } + save(p1, "testest.mgx"); run_pass(p1); // ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure // it happens. @@ -2107,7 +2106,7 @@ TEST_CASE(dot_add_multi_user_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2157,7 +2156,7 @@ TEST_CASE(dot_add_multi_user_dot_with_transpose) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2209,7 +2208,7 @@ TEST_CASE(dot_add_multi_user_dot_two_externals) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, external_t1, external_t2}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2265,7 +2264,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2321,7 +2320,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2381,7 +2380,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2473,7 +2472,7 @@ TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({add, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2565,7 +2564,7 @@ TEST_CASE(dot_pw_multi_user_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); mm->add_return({elemwise, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2615,7 +2614,7 @@ TEST_CASE(dot_multi_user_add_dot) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); mm->add_return({dot2, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -2663,7 +2662,7 @@ TEST_CASE(dot_add_dot_both_multi_user) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); mm->add_return({add, dot2, transpose}); } - run_pass(p1, true); + run_pass(p1, {.enable_geg_multi_out_intermediates = true}); migraphx::program p2; { auto* mm = p2.get_main_module(); From b58cb2382220ceb432a996a7aa450a7405f93e1a Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 29 Oct 2025 15:40:28 +0000 Subject: [PATCH 26/28] more sort tests; rename sort; more navi geg tests; slight optimization of moving sort into branch for multi-out case instead of outside the conditional --- src/include/migraphx/module.hpp | 2 +- src/module.cpp | 6 +- src/targets/gpu/fuse_mlir.cpp | 22 +++-- test/gpu/fuse_mlir.cpp | 50 ++++++++-- test/module_test.cpp | 167 +++++++++++++++++++++++++++++++- 5 files changed, 223 insertions(+), 24 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index ec109d98b6d..d68b2683e65 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -339,7 +339,7 @@ struct MIGRAPHX_EXPORT module ins_dep_map calc_implicit_deps() const; void repeat_while_changes(std::size_t n, const std::function& f); - void localized_sort(instruction_ref start_ins, instruction_ref end_ins); + void hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y); diff --git a/src/module.cpp b/src/module.cpp index 7b00c596030..4838d241904 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1600,9 +1600,9 @@ void module::repeat_while_changes(std::size_t n, const std::function& f) } } -// For topologically sorting a region in a module, canonically, such that the -// dependent chain between the two input instructions is last -void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) +// Hoists external inputs (instructions not in the dependency chain between start_ins and end_ins) +// to before start_ins, while preserving topological order +void module::hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins) { // get the chain of instructions between start_ins and end_ins, inclusive auto fusion_ins = find_instructions_between(start_ins, end_ins, this); diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 3583b33ee5b..39823b59576 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -913,17 +913,16 @@ struct find_mlir_fused_geg_ops mm->add_return(return_vals); auto inputs = find_inputs(map_ins, &mpm.get_module(), mm); - // sort fusion section of module such that any external inputs are moved before the fusion - // so that we can safely place the fused mod in the multi-out case at the beginning of the - // chain - mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins); - auto fused_ins = - mpm.get_module().insert_instruction(first_gemm_ins, - mlir_op{second_gemm_ins->get_operator()}, - mlir_contiguous(mpm, inputs), - {mm}); if(first_gemm_has_multi_outs or elemwise_has_multi_outs) { + // hoist external inputs before the fusion so that we can safely place the fused mod + // in the multi-out case at the beginning of the chain + mpm.get_module().hoist_external_inputs(first_gemm_ins, second_gemm_ins); + auto fused_ins = + mpm.get_module().insert_instruction(first_gemm_ins, + mlir_op{second_gemm_ins->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); std::size_t output_idx = 0; if(elemwise_has_multi_outs) { @@ -946,6 +945,11 @@ struct find_mlir_fused_geg_ops else { // simple single output case + auto fused_ins = + mpm.get_module().insert_instruction(first_gemm_ins, + mlir_op{second_gemm_ins->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); mpm.get_module().replace_instruction(second_gemm_ins, fused_ins); } } diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 1758b929da0..7a18e98988b 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -24,13 +24,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -55,6 +55,8 @@ struct non_mlir_op static void run_pass(migraphx::program& p, migraphx::gpu::fuse_mlir fm = {}) { + static migraphx::gpu::context ctx; + fm.ctx = &ctx; fm.enable_extra = true; migraphx::run_passes(p, {fm, migraphx::dead_code_elimination{}}); } @@ -1816,13 +1818,13 @@ TEST_CASE(dot_add_dot) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(dot_add_dot_abc) +TEST_CASE(dot_add_dot_abc_f32) // MLIR currently only supports (A*B)*C GEG patterns { - migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; - migraphx::shape s3{migraphx::shape::half_type, {2, 4}}; - migraphx::shape s4{migraphx::shape::half_type, {4, 2}}; + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -1836,7 +1838,6 @@ TEST_CASE(dot_add_dot_abc) auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} mm->add_return({dot2}); } - save(p1, "testest.mgx"); run_pass(p1); // ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure // it happens. @@ -1844,7 +1845,7 @@ TEST_CASE(dot_add_dot_abc) ss << p1; std::string program_str = ss.str(); - // regardless if the matmul is correctly oriented, geg should not happen on navi + // regardless if the matmul is correctly oriented, f32 geg should not happen on navi auto device_name = migraphx::gpu::get_device_name(); bool is_navi = migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12"); @@ -1855,6 +1856,39 @@ TEST_CASE(dot_add_dot_abc) EXPECT(program_str.find("geg") != std::string::npos); } +TEST_CASE(dot_add_dot_abc_fp16) +// MLIR currently only supports (A*B)*C GEG patterns +{ + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::half_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::half_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2} + mm->add_return({dot2}); + } + run_pass(p1); + std::stringstream ss; + ss << p1; + std::string program_str = ss.str(); + + // ensure "geg" is present if the fusion flag is enabled; type is fp16 so it should + // run regardless of if navi + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + EXPECT(program_str.find("geg") == std::string::npos); + else + EXPECT(program_str.find("geg") != std::string::npos); +} + TEST_CASE(dot_add_dot_cab) // MLIR currently does not support C*(A*B) GEG patterns { diff --git a/test/module_test.cpp b/test/module_test.cpp index 270bfcd63e4..40ba46f1d7c 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1369,7 +1369,7 @@ TEST_CASE(pathological_dfs_graph_sort) EXPECT(is_sorted(m)); } -TEST_CASE(localized_sort) +TEST_CASE(hoist_external_inputs) { migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; @@ -1394,8 +1394,8 @@ TEST_CASE(localized_sort) mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); mm->add_return({add, transpose}); - // Perform localized sort between dot1 and dot2 - mm->localized_sort(dot1, dot2); + // Hoist external inputs between dot1 and dot2 + mm->hoist_external_inputs(dot1, dot2); // Verify the module is still topologically sorted overall EXPECT(is_sorted(*mm)); @@ -1415,4 +1415,165 @@ TEST_CASE(localized_sort) EXPECT(std::distance(mm->begin(), dot2) < std::distance(mm->begin(), transpose)); } +TEST_CASE(hoist_external_inputs_no_movement_needed) +{ + // Test where external dependencies are already before the fusion chain + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s3); + + // External operations already positioned before the fusion chain + auto external1 = add_pointwise(p, "main:pointwise0", {a}, single_pointwise("relu")); + auto external2 = add_pointwise(p, "main:pointwise1", {b}, single_pointwise("tanh")); + auto external3 = add_pointwise(p, "main:pointwise2", {c}, single_pointwise("tanh")); + + // Fusion chain + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), external1, external2); + auto add = add_pointwise(p, "main:pointwise3", {dot1, external3}, single_pointwise("add")); + + mm->add_return({add}); + + // Record positions before hoist_external_inputs + auto dot1_pos_before = std::distance(mm->begin(), dot1); + auto add_pos_before = std::distance(mm->begin(), add); + + mm->hoist_external_inputs(dot1, add); + + // Verify positions haven't changed (nothing needed to move) + EXPECT(std::distance(mm->begin(), dot1) == dot1_pos_before); + EXPECT(std::distance(mm->begin(), add) == add_pos_before); + EXPECT(is_sorted(*mm)); +} + +TEST_CASE(hoist_external_inputs_multiple_external_branches) +{ + // Test with multiple independent external branches that need to be moved + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {4, 3}}; + migraphx::shape s4{migraphx::shape::float_type, {2, 4}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s4); + auto d = mm->add_parameter("d", s3); + + // Start of fusion chain + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4} + + // External branch 1 + auto ext1 = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {2, 4} + auto ext2 = add_pointwise(p, "main:pointwise1", {ext1}, single_pointwise("neg")); // {2, 4} + + // External branch 2 + auto ext3 = add_pointwise(p, "main:pointwise2", {d}, single_pointwise("tanh")); // {4, 3} + auto ext4 = add_pointwise(p, "main:pointwise3", {ext3}, single_pointwise("abs")); // {4, 3} + + // Continue fusion chain using external branches + auto add1 = add_pointwise(p, "main:pointwise4", {dot1, ext2}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add1, ext4); // {2, 4} x {4, 3} = {2, 3} + + mm->add_return({dot2}); + + mm->hoist_external_inputs(dot1, dot2); + + EXPECT(is_sorted(*mm)); + + // Verify all external operations moved before dot1 + EXPECT(std::distance(mm->begin(), ext1) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), ext2) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), ext3) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), ext4) < std::distance(mm->begin(), dot1)); + + // Verify fusion chain order preserved + EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), add1)); + EXPECT(std::distance(mm->begin(), add1) < std::distance(mm->begin(), dot2)); +} + +TEST_CASE(hoist_external_inputs_long_fusion_chain) +{ + // Test with a longer fusion chain + migraphx::shape s{migraphx::shape::float_type, {4, 4}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + // Long fusion chain + auto op1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + + // External operations interspersed + auto ext1 = add_pointwise(p, "main:pointwise0", {d}, single_pointwise("exp")); + + auto op2 = add_pointwise(p, "main:pointwise1", {op1, c}, single_pointwise("add")); + + auto ext2 = add_pointwise(p, "main:pointwise2", {ext1}, single_pointwise("log")); + + auto op3 = add_pointwise(p, "main:pointwise3", {op2}, single_pointwise("relu")); + + auto ext3 = add_pointwise(p, "main:pointwise4", {ext2}, single_pointwise("abs")); + + auto op4 = add_pointwise(p, "main:pointwise5", {op3}, single_pointwise("tanh")); + auto op5 = mm->add_instruction(migraphx::make_op("dot"), op4, ext3); + + mm->add_return({op5}); + + mm->hoist_external_inputs(op1, op5); + + EXPECT(is_sorted(*mm)); + + // All external operations moved before op1 + EXPECT(std::distance(mm->begin(), ext1) < std::distance(mm->begin(), op1)); + EXPECT(std::distance(mm->begin(), ext2) < std::distance(mm->begin(), op1)); + EXPECT(std::distance(mm->begin(), ext3) < std::distance(mm->begin(), op1)); + + // Fusion chain order preserved + EXPECT(std::distance(mm->begin(), op1) < std::distance(mm->begin(), op2)); + EXPECT(std::distance(mm->begin(), op2) < std::distance(mm->begin(), op3)); + EXPECT(std::distance(mm->begin(), op3) < std::distance(mm->begin(), op4)); + EXPECT(std::distance(mm->begin(), op4) < std::distance(mm->begin(), op5)); +} + +TEST_CASE(hoist_external_inputs_adjacent_instructions) +{ + // Test edge case where start and end are adjacent + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {4, 5}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); // {2, 3} + auto b = mm->add_parameter("b", s2); // {3, 4} + auto c = mm->add_parameter("c", s3); // {4, 5} + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 3} x {3, 4} = {2, 4} + + // External operation between dot1 and dot2 + auto external = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {4, 5} + + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), dot1, external); // {2, 4} x {4, 5} = {2, 5} + + mm->add_return({dot2}); + + mm->hoist_external_inputs(dot1, dot2); + + EXPECT(is_sorted(*mm)); + + // External should be moved before dot1 + EXPECT(std::distance(mm->begin(), external) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), dot2)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 9316e0a7772b27e3cd5ac6df84e9b58a0cbe45fe Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 3 Nov 2025 10:31:07 -0600 Subject: [PATCH 27/28] format --- test/gpu/fuse_mlir.cpp | 2 +- test/module_test.cpp | 27 +++++++++++++++------------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 7a18e98988b..7419e615ce7 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -56,7 +56,7 @@ struct non_mlir_op static void run_pass(migraphx::program& p, migraphx::gpu::fuse_mlir fm = {}) { static migraphx::gpu::context ctx; - fm.ctx = &ctx; + fm.ctx = &ctx; fm.enable_extra = true; migraphx::run_passes(p, {fm, migraphx::dead_code_elimination{}}); } diff --git a/test/module_test.cpp b/test/module_test.cpp index 40ba46f1d7c..87ab9019e13 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1435,13 +1435,13 @@ TEST_CASE(hoist_external_inputs_no_movement_needed) // Fusion chain auto dot1 = mm->add_instruction(migraphx::make_op("dot"), external1, external2); - auto add = add_pointwise(p, "main:pointwise3", {dot1, external3}, single_pointwise("add")); + auto add = add_pointwise(p, "main:pointwise3", {dot1, external3}, single_pointwise("add")); mm->add_return({add}); // Record positions before hoist_external_inputs auto dot1_pos_before = std::distance(mm->begin(), dot1); - auto add_pos_before = std::distance(mm->begin(), add); + auto add_pos_before = std::distance(mm->begin(), add); mm->hoist_external_inputs(dot1, add); @@ -1470,16 +1470,18 @@ TEST_CASE(hoist_external_inputs_multiple_external_branches) auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4} // External branch 1 - auto ext1 = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {2, 4} + auto ext1 = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {2, 4} auto ext2 = add_pointwise(p, "main:pointwise1", {ext1}, single_pointwise("neg")); // {2, 4} // External branch 2 - auto ext3 = add_pointwise(p, "main:pointwise2", {d}, single_pointwise("tanh")); // {4, 3} + auto ext3 = add_pointwise(p, "main:pointwise2", {d}, single_pointwise("tanh")); // {4, 3} auto ext4 = add_pointwise(p, "main:pointwise3", {ext3}, single_pointwise("abs")); // {4, 3} // Continue fusion chain using external branches - auto add1 = add_pointwise(p, "main:pointwise4", {dot1, ext2}, single_pointwise("add")); // {2, 4} - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add1, ext4); // {2, 4} x {4, 3} = {2, 3} + auto add1 = + add_pointwise(p, "main:pointwise4", {dot1, ext2}, single_pointwise("add")); // {2, 4} + auto dot2 = + mm->add_instruction(migraphx::make_op("dot"), add1, ext4); // {2, 4} x {4, 3} = {2, 3} mm->add_return({dot2}); @@ -1554,16 +1556,17 @@ TEST_CASE(hoist_external_inputs_adjacent_instructions) migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", s1); // {2, 3} - auto b = mm->add_parameter("b", s2); // {3, 4} - auto c = mm->add_parameter("c", s3); // {4, 5} + auto a = mm->add_parameter("a", s1); // {2, 3} + auto b = mm->add_parameter("b", s2); // {3, 4} + auto c = mm->add_parameter("c", s3); // {4, 5} - auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 3} x {3, 4} = {2, 4} + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 3} x {3, 4} = {2, 4} // External operation between dot1 and dot2 - auto external = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {4, 5} + auto external = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {4, 5} - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), dot1, external); // {2, 4} x {4, 5} = {2, 5} + auto dot2 = + mm->add_instruction(migraphx::make_op("dot"), dot1, external); // {2, 4} x {4, 5} = {2, 5} mm->add_return({dot2}); From bc9d3f9c1722891a14a0d428ffb7bda28a1a2ff2 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 3 Nov 2025 13:39:39 -0600 Subject: [PATCH 28/28] fix placement of fused module --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 39823b59576..6c28043bc9c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -946,7 +946,7 @@ struct find_mlir_fused_geg_ops { // simple single output case auto fused_ins = - mpm.get_module().insert_instruction(first_gemm_ins, + mpm.get_module().insert_instruction(second_gemm_ins, mlir_op{second_gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm});