g+g disabling for navi/f32, sort fix, force gemm0 to be first input of gemm1, toggle off intermediate outs #4384
Merged
Changes from all commits (29 commits, all by bdevorem):
ce00b89  fix localized sort
92aead8  unsupported fusions and clang format
55e0868  comment
cd069bc  format
93e9607  format
bac1bca  ci failures
37c31cd  change geg fuse_mlir tests to be f16 to avoid needing to worry about …
89fb484  comment
b3dd3bf  style issues
bcf4173  format
927ec55  format again
92b4435  f32 disabled for all navis
cbc483e  force ordering of gemm1 inputs in geg
5cf6f93  toggle off multi out support for intermediates in g+g fusion
23585af  format
cc75a89  format
8ea8264  make flag in pass class for toggling multi out intermediates support
e6d7b47  format/tidy
dcf75fe  edit test ensuring geg fusion happens on correctly oriented gemms doe…
461bb73  format
2346b17  format
ef3aa00  format
d8df553  fix bad conditional
376747c  ignore test if flag is not enabled
15b33e6  use initializer instead of bool parameter for enabling multi out g+g
b58cb23  more sort tests; rename sort; more navi geg tests; slight optimizatio…
9316e0a  format
bc9d3f9  fix placement of fused module
fadef65  Merge branch 'develop' into bdevorem/geg_heuristics
Diff view
@@ -788,6 +788,23 @@ struct find_mlir_fused_geg_ops
 {
     mlir_mode conv_mode = mlir_mode::none;
     mlir_mode dot_mode = mlir_mode::none;
+    std::string gfx_name;
+    bool enable_geg_multi_out_intermediates = false;
+
+    // check if individual GEMM meets architecture requirements
+    bool is_gemm_supported(instruction_ref ins) const
+    {
+        // on navis, wmma doesn't support fp32, so skip fp32 GEMMs
+        // one gemm being f32 is sufficient to turn off this fusion
+        if(starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12"))
+        {
+            if(ins->get_shape().type() == shape::type_t::float_type)
+            {
+                return false;
+            }
+        }
+        return true;
+    }
 
     /*
      * Matches:

(A collaborator commented on the explanatory comment in is_gemm_supported: "appreciate the comment here".)
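For reference, the gate added above reduces to a prefix check on the device name plus a dtype check. Below is a minimal standalone C++ sketch of that logic; starts_with and is_f32_gemm_allowed here are illustrative stand-ins, not MIGraphX APIs.

#include <string>
#include <string_view>

// Illustrative stand-in for the starts_with helper used in the diff.
static bool starts_with(std::string_view s, std::string_view prefix)
{
    return s.substr(0, prefix.size()) == prefix;
}

// Hypothetical free function mirroring is_gemm_supported: on Navi parts
// (gfx11*/gfx12*), WMMA has no fp32 path, so an fp32 GEMM disables the fusion.
static bool is_f32_gemm_allowed(const std::string& gfx_name, bool gemm_is_fp32)
{
    if(starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12"))
        return not gemm_is_fp32;
    return true;
}

// Example: is_f32_gemm_allowed("gfx1100", true) == false
//          is_f32_gemm_allowed("gfx942",  true) == true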
@@ -797,11 +814,15 @@ struct find_mlir_fused_geg_ops
      */
     auto matcher() const
     {
-        auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))
-                                     .bind("first_gemm_based_op");
+        auto gemm_supported = match::make_basic_pred_matcher(
+            [this](instruction_ref ins) { return is_gemm_supported(ins); });
+
+        auto first_dot_or_conv =
+            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))(gemm_supported)
+                .bind("first_gemm_based_op");
         auto elemwise =
             mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise");
-        return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise))
+        return is_mlir_dot(dot_mode)(gemm_supported)(match::arg(0)(elemwise))
             .bind("second_gemm_op");
     }
 
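The change from match::any_of[match::inputs()](elemwise) to match::arg(0)(elemwise) is what forces the gemm0 chain to be the first input of gemm1. The toy sketch below (a hypothetical Node struct and helpers, not MIGraphX's matcher machinery) contrasts matching any input against matching only argument 0.

#include <algorithm>
#include <vector>

struct Node
{
    std::vector<const Node*> inputs;
};

// Before: the chain may feed any input of the second gemm,
// mirroring match::any_of[match::inputs()](elemwise).
static bool feeds_any_input(const Node& second_gemm, const Node& chain)
{
    return std::any_of(second_gemm.inputs.begin(), second_gemm.inputs.end(),
                       [&](const Node* in) { return in == &chain; });
}

// After: the chain must feed argument 0 of the second gemm,
// mirroring match::arg(0)(elemwise).
static bool feeds_first_input(const Node& second_gemm, const Node& chain)
{
    return not second_gemm.inputs.empty() and second_gemm.inputs.front() == &chain;
}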
@@ -820,10 +841,11 @@ struct find_mlir_fused_geg_ops
            }))
             return;
 
-        // only one input to second_gemm should depend on elemwise
+        // only one input to second_gemm should depend on elemwise or first_gemm
         auto second_gemm_inputs = second_gemm_ins->inputs();
         if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) {
-               return i != elemwise_ins and reaches(elemwise_ins, i);
+               return (i != elemwise_ins and reaches(elemwise_ins, i)) and
+                      (i != first_gemm_ins and reaches(first_gemm_ins, i));
            }))
             return;
 
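The guard above walks the other inputs of the second gemm and bails out if one of them also depends on the fused intermediates. A toy version of the same std::any_of-over-reachability pattern is sketched below; Node and reaches are hypothetical stand-ins (MIGraphX's reaches operates on instruction_refs inside a module), and only the single-intermediate form of the predicate is shown.

#include <algorithm>
#include <unordered_set>
#include <vector>

struct Node
{
    std::vector<const Node*> inputs;
};

// Depth-first search over the input edges: does `to` transitively depend on `from`?
static bool reaches(const Node* from, const Node* to)
{
    std::unordered_set<const Node*> seen;
    std::vector<const Node*> stack{to};
    while(not stack.empty())
    {
        const Node* n = stack.back();
        stack.pop_back();
        if(n == from)
            return true;
        if(not seen.insert(n).second)
            continue;
        stack.insert(stack.end(), n->inputs.begin(), n->inputs.end());
    }
    return false;
}

// Guard pattern from the diff: reject the fusion if any other input of the
// second gemm depends on the given intermediate.
static bool has_external_dependency(const Node* second_gemm, const Node* intermediate)
{
    return std::any_of(second_gemm->inputs.begin(), second_gemm->inputs.end(),
                       [&](const Node* i) { return i != intermediate and reaches(intermediate, i); });
}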
@@ -837,6 +859,11 @@ struct find_mlir_fused_geg_ops
         bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1;
         bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1;
 
+        // if we have multi outs for either of the intermediates, check if this is supported first
+        if((first_gemm_has_multi_outs or elemwise_has_multi_outs) and
+           not enable_geg_multi_out_intermediates)
+            return;
+
         // add the first gemm to the module
         std::vector<instruction_ref> first_gemm_mapped_inputs;
         first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size());
@@ -886,19 +913,16 @@ struct find_mlir_fused_geg_ops
         mm->add_return(return_vals);
         auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);
 
-        // sort fusion section of module such that any external inputs are moved before the fusion
-        // so that we can safely place the fused mod in the multi-out case at the beginning of the
-        // chain
-        mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins);
-
-        auto fused_ins =
-            mpm.get_module().insert_instruction(first_gemm_ins,
-                                                mlir_op{second_gemm_ins->get_operator()},
-                                                mlir_contiguous(mpm, inputs),
-                                                {mm});
-
         if(first_gemm_has_multi_outs or elemwise_has_multi_outs)
         {
+            // hoist external inputs before the fusion so that we can safely place the fused mod
+            // in the multi-out case at the beginning of the chain
+            mpm.get_module().hoist_external_inputs(first_gemm_ins, second_gemm_ins);
+            auto fused_ins =
+                mpm.get_module().insert_instruction(first_gemm_ins,
+                                                    mlir_op{second_gemm_ins->get_operator()},
+                                                    mlir_contiguous(mpm, inputs),
+                                                    {mm});
             std::size_t output_idx = 0;
             if(elemwise_has_multi_outs)
             {
@@ -921,6 +945,11 @@ struct find_mlir_fused_geg_ops
         else
         {
             // simple single output case
+            auto fused_ins =
+                mpm.get_module().insert_instruction(second_gemm_ins,
+                                                    mlir_op{second_gemm_ins->get_operator()},
+                                                    mlir_contiguous(mpm, inputs),
+                                                    {mm});
             mpm.get_module().replace_instruction(second_gemm_ins, fused_ins);
         }
     }
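The two branches differ mainly in where the fused mlir_op is placed: with multi-output intermediates it sits at the first gemm's position (hence the earlier hoist of external inputs), while in the single-output case it simply takes the second gemm's slot. A toy sketch of that placement choice over a flat instruction list follows; the names and the std::list stand-in are hypothetical, not the module API.

#include <algorithm>
#include <iostream>
#include <list>
#include <string>

int main()
{
    // Stand-in for the instruction stream around the fusion.
    std::list<std::string> module{"external_input", "first_gemm", "elemwise", "second_gemm", "consumer"};

    bool multi_out_intermediates = true;

    // Multi-out intermediates: the fused op takes first_gemm's position so its extra
    // outputs are defined before every use; single output: it takes second_gemm's slot.
    auto anchor = std::find(module.begin(), module.end(),
                            std::string(multi_out_intermediates ? "first_gemm" : "second_gemm"));
    module.insert(anchor, "fused_mlir_op");

    for(const auto& ins : module)
        std::cout << ins << '\n';
}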
@@ -1452,7 +1481,10 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
     match::find_matches(
         mpm,
         find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
-                                .dot_mode = get_mode("fused_dot", mlir_mode::fast)});
+                                .dot_mode = get_mode("fused_dot", mlir_mode::fast),
+                                .gfx_name = device_name,
+                                .enable_geg_multi_out_intermediates =
+                                    enable_geg_multi_out_intermediates});
     mpm.run_pass(dead_code_elimination{});
 }
 
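fuse_mlir::apply now threads the device name and the multi-out toggle into the finder via designated initializers. A reduced, self-contained sketch of that wiring follows; ToyGegFinder and run_geg_pass are hypothetical stand-ins that only mirror the two new fields.

#include <iostream>
#include <string>

// Toy stand-in exposing the same configuration surface as find_mlir_fused_geg_ops.
struct ToyGegFinder
{
    std::string gfx_name;
    bool enable_geg_multi_out_intermediates = false;
};

static void run_geg_pass(const ToyGegFinder& finder)
{
    std::cout << "target=" << finder.gfx_name
              << " multi_out_intermediates=" << std::boolalpha
              << finder.enable_geg_multi_out_intermediates << '\n';
}

int main()
{
    // C++20 designated initializers, matching the style used in fuse_mlir::apply.
    run_geg_pass(ToyGegFinder{.gfx_name = "gfx1100",
                              .enable_geg_multi_out_intermediates = false});
}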