Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ce00b89
fix localized sort
bdevorem Oct 15, 2025
92aead8
unsupported fusions and clang format
bdevorem Oct 15, 2025
55e0868
comment
bdevorem Oct 15, 2025
cd069bc
format
bdevorem Oct 15, 2025
93e9607
format
bdevorem Oct 15, 2025
bac1bca
ci failures
bdevorem Oct 15, 2025
37c31cd
change geg fuse_mlir tests to be f16 to avoid needing to worry about …
bdevorem Oct 15, 2025
89fb484
comment
bdevorem Oct 15, 2025
b3dd3bf
style issues
bdevorem Oct 15, 2025
bcf4173
format
bdevorem Oct 15, 2025
927ec55
format again
bdevorem Oct 15, 2025
92b4435
f32 disabled for all navis
bdevorem Oct 21, 2025
cbc483e
force ordering of gemm1 inputs in geg
bdevorem Oct 21, 2025
5cf6f93
toggle off multi out support for intermediates in g+g fusion
bdevorem Oct 21, 2025
23585af
format
bdevorem Oct 21, 2025
cc75a89
format
bdevorem Oct 22, 2025
8ea8264
make flag in pass class for toggling multi out intermediates support
bdevorem Oct 22, 2025
e6d7b47
format/tidy
bdevorem Oct 22, 2025
dcf75fe
edit test ensuring geg fusion happens on correctly oriented gemms doe…
bdevorem Oct 22, 2025
461bb73
format
bdevorem Oct 22, 2025
2346b17
format
bdevorem Oct 22, 2025
ef3aa00
format
bdevorem Oct 22, 2025
d8df553
fix bad conditional
bdevorem Oct 24, 2025
376747c
ignore test if flag is not enabled
bdevorem Oct 24, 2025
15b33e6
use initializer instead of bool parameter for enabling multi out g+g
bdevorem Oct 29, 2025
b58cb23
more sort tests; rename sort; more navi geg tests; slight optimizatio…
bdevorem Oct 29, 2025
9316e0a
format
bdevorem Nov 3, 2025
bc9d3f9
fix placement of fused module
bdevorem Nov 3, 2025
fadef65
Merge branch 'develop' into bdevorem/geg_heuristics
bdevorem Nov 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ struct MIGRAPHX_EXPORT module
ins_dep_map calc_implicit_deps() const;

void repeat_while_changes(std::size_t n, const std::function<void()>& f);
void localized_sort(instruction_ref start_ins, instruction_ref end_ins);
void hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins);

MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);
Expand Down
30 changes: 24 additions & 6 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,9 +1600,9 @@ void module::repeat_while_changes(std::size_t n, const std::function<void()>& f)
}
}

// For topologically sorting a region in a module, canonically, such that the
// dependent chain between the two input instructions is last
void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins)
// Hoists external inputs (instructions not in the dependency chain between start_ins and end_ins)
// to before start_ins, while preserving topological order
void module::hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins)
{
// get the chain of instructions between start_ins and end_ins, inclusive
auto fusion_ins = find_instructions_between(start_ins, end_ins, this);
Expand All @@ -1614,15 +1614,33 @@ void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins)
{
if(fusion_ins.count(it) == 0)
{
auto next = std::next(it); // move_instruction updates the iterator
this->move_instruction(it, start_ins);
it = next;
// only move if none of its inputs are after start_ins
bool has_input_in_range =
std::any_of(it->inputs().begin(), it->inputs().end(), [&](instruction_ref input) {
if(not has_instruction(input))
return false;
// verify: start_ins < input < it
return std::find(std::next(start_ins), it, input) != it;
});

if(has_input_in_range)
{
// input is after start_ins, meaning can't move this instruction
++it;
}
else
{
auto next = std::next(it);
this->move_instruction(it, start_ins);
it = next;
}
}
else
{
++it;
}
}
assert(this->validate() == this->end());
}

bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); }
Expand Down
66 changes: 49 additions & 17 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,23 @@ struct find_mlir_fused_geg_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
std::string gfx_name;
bool enable_geg_multi_out_intermediates = false;

// Check whether an individual GEMM instruction meets the architecture
// requirements for this fusion.
//
// ins: the instruction whose shape type (ins->get_shape().type()) is inspected.
// Returns false only when gfx_name starts with "gfx11" or "gfx12" and that
// shape type is float_type; returns true in every other case.
bool is_gemm_supported(instruction_ref ins) const
{
// on navis, wmma doesn't support fp32, so skip fp32 GEMMs
// one gemm being f32 is sufficient to turn off this fusion
// NOTE(review): the non-code lines that follow are code-review page text
// captured together with the diff, not part of the function
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

appreciate the comment here

// architecture gate: only the gfx11* / gfx12* device names are restricted
if(starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12"))
{
// an f32 result type disqualifies this GEMM on those devices
if(ins->get_shape().type() == shape::type_t::float_type)
{
return false;
}
}
// any other device name or data type is accepted
return true;
}

/*
* Matches:
Expand All @@ -797,11 +814,15 @@ struct find_mlir_fused_geg_ops
*/
auto matcher() const
{
auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))
.bind("first_gemm_based_op");
auto gemm_supported = match::make_basic_pred_matcher(
[this](instruction_ref ins) { return is_gemm_supported(ins); });

auto first_dot_or_conv =
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))(gemm_supported)
.bind("first_gemm_based_op");
auto elemwise =
mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise");
return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise))
return is_mlir_dot(dot_mode)(gemm_supported)(match::arg(0)(elemwise))
.bind("second_gemm_op");
}

Expand All @@ -820,10 +841,11 @@ struct find_mlir_fused_geg_ops
}))
return;

// only one input to second_gemm should depend on elemwise
// only one input to second_gemm should depend on elemwise or first_gemm
auto second_gemm_inputs = second_gemm_ins->inputs();
if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) {
return i != elemwise_ins and reaches(elemwise_ins, i);
return (i != elemwise_ins and reaches(elemwise_ins, i)) and
(i != first_gemm_ins and reaches(first_gemm_ins, i));
}))
return;

Expand All @@ -837,6 +859,11 @@ struct find_mlir_fused_geg_ops
bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1;
bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1;

// if we have multi outs for either of the intermediates, check if this is supported first
if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and
not enable_geg_multi_out_intermediates)
return;

// add the first gemm to the module
std::vector<instruction_ref> first_gemm_mapped_inputs;
first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size());
Expand Down Expand Up @@ -886,19 +913,16 @@ struct find_mlir_fused_geg_ops
mm->add_return(return_vals);
auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);

// sort fusion section of module such that any external inputs are moved before the fusion
// so that we can safely place the fused mod in the multi-out case at the beginning of the
// chain
mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins);

auto fused_ins =
mpm.get_module().insert_instruction(first_gemm_ins,
mlir_op{second_gemm_ins->get_operator()},
mlir_contiguous(mpm, inputs),
{mm});

if(first_gemm_has_multi_outs or elemwise_has_multi_outs)
{
// hoist external inputs before the fusion so that we can safely place the fused mod
// in the multi-out case at the beginning of the chain
mpm.get_module().hoist_external_inputs(first_gemm_ins, second_gemm_ins);
auto fused_ins =
mpm.get_module().insert_instruction(first_gemm_ins,
mlir_op{second_gemm_ins->get_operator()},
mlir_contiguous(mpm, inputs),
{mm});
std::size_t output_idx = 0;
if(elemwise_has_multi_outs)
{
Expand All @@ -921,6 +945,11 @@ struct find_mlir_fused_geg_ops
else
{
// simple single output case
auto fused_ins =
mpm.get_module().insert_instruction(second_gemm_ins,
mlir_op{second_gemm_ins->get_operator()},
mlir_contiguous(mpm, inputs),
{mm});
mpm.get_module().replace_instruction(second_gemm_ins, fused_ins);
}
}
Expand Down Expand Up @@ -1452,7 +1481,10 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches(
mpm,
find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
.dot_mode = get_mode("fused_dot", mlir_mode::fast),
.gfx_name = device_name,
.enable_geg_multi_out_intermediates =
enable_geg_multi_out_intermediates});
mpm.run_pass(dead_code_elimination{});
}

Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ namespace gpu {

MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(context* ctx);
MIGRAPHX_GPU_EXPORT bool mlir_geg_multi_user_intermediates_supported();

// Pass object reported under the name "gpu::fuse_mlir". All members have
// defaults, so it can be created with only the fields of interest set.
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
// NOTE(review): presumably the GPU context the pass runs against; null by
// default -- confirm how apply() uses it
context* ctx = nullptr;
// NOTE(review): purpose is not visible in this header; off by default
bool enable_extra = false;
// Forwarded by apply() to find_mlir_fused_geg_ops. When false, that matcher
// gives up on the gemm+elementwise+gemm fusion if the first gemm or the
// elementwise instruction has more than one output. Off by default.
bool enable_geg_multi_out_intermediates = false;
// Name under which this pass is reported.
std::string name() const { return "gpu::fuse_mlir"; }
// Runs the pass through the given module pass manager; defined out of line
// in src/targets/gpu/fuse_mlir.cpp.
void apply(module_pass_manager& mpm) const;
};
Expand Down
Loading
Loading