Skip to content

Commit b58cb23

Browse files
committed
more sort tests; rename sort; more navi geg tests; slight optimization of moving sort into branch for multi-out case instead of outside the conditional
1 parent 15b33e6 commit b58cb23

File tree

5 files changed

+223
-24
lines changed

5 files changed

+223
-24
lines changed

src/include/migraphx/module.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ struct MIGRAPHX_EXPORT module
339339
ins_dep_map calc_implicit_deps() const;
340340

341341
void repeat_while_changes(std::size_t n, const std::function<void()>& f);
342-
void localized_sort(instruction_ref start_ins, instruction_ref end_ins);
342+
void hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins);
343343

344344
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
345345
MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);

src/module.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,9 +1600,9 @@ void module::repeat_while_changes(std::size_t n, const std::function<void()>& f)
16001600
}
16011601
}
16021602

1603-
// For topologically sorting a region in a module, canonically, such that the
1604-
// dependent chain between the two input instructions is last
1605-
void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins)
1603+
// Hoists external inputs (instructions not in the dependency chain between start_ins and end_ins)
1604+
// to before start_ins, while preserving topological order
1605+
void module::hoist_external_inputs(instruction_ref start_ins, instruction_ref end_ins)
16061606
{
16071607
// get the chain of instructions between start_ins and end_ins, inclusive
16081608
auto fusion_ins = find_instructions_between(start_ins, end_ins, this);

src/targets/gpu/fuse_mlir.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -913,17 +913,16 @@ struct find_mlir_fused_geg_ops
913913
mm->add_return(return_vals);
914914
auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);
915915

916-
// sort fusion section of module such that any external inputs are moved before the fusion
917-
// so that we can safely place the fused mod in the multi-out case at the beginning of the
918-
// chain
919-
mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins);
920-
auto fused_ins =
921-
mpm.get_module().insert_instruction(first_gemm_ins,
922-
mlir_op{second_gemm_ins->get_operator()},
923-
mlir_contiguous(mpm, inputs),
924-
{mm});
925916
if(first_gemm_has_multi_outs or elemwise_has_multi_outs)
926917
{
918+
// hoist external inputs before the fusion so that we can safely place the fused mod
919+
// in the multi-out case at the beginning of the chain
920+
mpm.get_module().hoist_external_inputs(first_gemm_ins, second_gemm_ins);
921+
auto fused_ins =
922+
mpm.get_module().insert_instruction(first_gemm_ins,
923+
mlir_op{second_gemm_ins->get_operator()},
924+
mlir_contiguous(mpm, inputs),
925+
{mm});
927926
std::size_t output_idx = 0;
928927
if(elemwise_has_multi_outs)
929928
{
@@ -946,6 +945,11 @@ struct find_mlir_fused_geg_ops
946945
else
947946
{
948947
// simple single output case
948+
auto fused_ins =
949+
mpm.get_module().insert_instruction(first_gemm_ins,
950+
mlir_op{second_gemm_ins->get_operator()},
951+
mlir_contiguous(mpm, inputs),
952+
{mm});
949953
mpm.get_module().replace_instruction(second_gemm_ins, fused_ins);
950954
}
951955
}

test/gpu/fuse_mlir.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
#include <migraphx/generate.hpp>
2525
#include <migraphx/dead_code_elimination.hpp>
2626
#include <migraphx/gpu/fuse_mlir.hpp>
27+
#include <migraphx/gpu/context.hpp>
2728
#include <migraphx/instruction.hpp>
2829
#include <migraphx/pass_manager.hpp>
2930
#include <migraphx/op/common.hpp>
3031
#include <migraphx/program.hpp>
3132
#include <migraphx/make_op.hpp>
3233
#include <migraphx/param_utils.hpp>
33-
#include <migraphx/load_save.hpp>
3434
#include <basic_ops.hpp>
3535
#include <group.hpp>
3636
#include <test.hpp>
@@ -55,6 +55,8 @@ struct non_mlir_op
5555

5656
static void run_pass(migraphx::program& p, migraphx::gpu::fuse_mlir fm = {})
5757
{
58+
static migraphx::gpu::context ctx;
59+
fm.ctx = &ctx;
5860
fm.enable_extra = true;
5961
migraphx::run_passes(p, {fm, migraphx::dead_code_elimination{}});
6062
}
@@ -1816,13 +1818,13 @@ TEST_CASE(dot_add_dot)
18161818
EXPECT(p1.sort() == p2.sort());
18171819
}
18181820

1819-
TEST_CASE(dot_add_dot_abc)
1821+
TEST_CASE(dot_add_dot_abc_f32)
18201822
// MLIR currently only supports (A*B)*C GEG patterns
18211823
{
1822-
migraphx::shape s1{migraphx::shape::half_type, {2, 3}};
1823-
migraphx::shape s2{migraphx::shape::half_type, {3, 4}};
1824-
migraphx::shape s3{migraphx::shape::half_type, {2, 4}};
1825-
migraphx::shape s4{migraphx::shape::half_type, {4, 2}};
1824+
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
1825+
migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
1826+
migraphx::shape s3{migraphx::shape::float_type, {2, 4}};
1827+
migraphx::shape s4{migraphx::shape::float_type, {4, 2}};
18261828
migraphx::program p1;
18271829
{
18281830
auto* mm = p1.get_main_module();
@@ -1836,15 +1838,14 @@ TEST_CASE(dot_add_dot_abc)
18361838
auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2}
18371839
mm->add_return({dot2});
18381840
}
1839-
save(p1, "testest.mgx");
18401841
run_pass(p1);
18411842
// ensure "geg" is present. Earlier tests ensure the fusion is correct. This is just to ensure
18421843
// it happens.
18431844
std::stringstream ss;
18441845
ss << p1;
18451846
std::string program_str = ss.str();
18461847

1847-
// regardless if the matmul is correctly oriented, geg should not happen on navi
1848+
// regardless if the matmul is correctly oriented, f32 geg should not happen on navi
18481849
auto device_name = migraphx::gpu::get_device_name();
18491850
bool is_navi =
18501851
migraphx::starts_with(device_name, "gfx11") or migraphx::starts_with(device_name, "gfx12");
@@ -1855,6 +1856,39 @@ TEST_CASE(dot_add_dot_abc)
18551856
EXPECT(program_str.find("geg") != std::string::npos);
18561857
}
18571858

1859+
TEST_CASE(dot_add_dot_abc_fp16)
1860+
// MLIR currently only supports (A*B)*C GEG patterns
1861+
{
1862+
migraphx::shape s1{migraphx::shape::half_type, {2, 3}};
1863+
migraphx::shape s2{migraphx::shape::half_type, {3, 4}};
1864+
migraphx::shape s3{migraphx::shape::half_type, {2, 4}};
1865+
migraphx::shape s4{migraphx::shape::half_type, {4, 2}};
1866+
migraphx::program p1;
1867+
{
1868+
auto* mm = p1.get_main_module();
1869+
auto a = mm->add_parameter("a", s1);
1870+
auto b = mm->add_parameter("b", s2);
1871+
auto x = mm->add_parameter("x", s3);
1872+
auto y = mm->add_parameter("y", s4);
1873+
auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4}
1874+
auto add =
1875+
add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2, 4}
1876+
auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {2, 4}*{4, 2} = {2, 2}
1877+
mm->add_return({dot2});
1878+
}
1879+
run_pass(p1);
1880+
std::stringstream ss;
1881+
ss << p1;
1882+
std::string program_str = ss.str();
1883+
1884+
// ensure "geg" is present if the fusion flag is enabled; type is fp16 so it should
1885+
// run regardless of if navi
1886+
if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}))
1887+
EXPECT(program_str.find("geg") == std::string::npos);
1888+
else
1889+
EXPECT(program_str.find("geg") != std::string::npos);
1890+
}
1891+
18581892
TEST_CASE(dot_add_dot_cab)
18591893
// MLIR currently does not support C*(A*B) GEG patterns
18601894
{

test/module_test.cpp

Lines changed: 164 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ TEST_CASE(pathological_dfs_graph_sort)
13691369
EXPECT(is_sorted(m));
13701370
}
13711371

1372-
TEST_CASE(localized_sort)
1372+
TEST_CASE(hoist_external_inputs)
13731373
{
13741374
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
13751375
migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
@@ -1394,8 +1394,8 @@ TEST_CASE(localized_sort)
13941394
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2);
13951395
mm->add_return({add, transpose});
13961396

1397-
// Perform localized sort between dot1 and dot2
1398-
mm->localized_sort(dot1, dot2);
1397+
// Hoist external inputs between dot1 and dot2
1398+
mm->hoist_external_inputs(dot1, dot2);
13991399

14001400
// Verify the module is still topologically sorted overall
14011401
EXPECT(is_sorted(*mm));
@@ -1415,4 +1415,165 @@ TEST_CASE(localized_sort)
14151415
EXPECT(std::distance(mm->begin(), dot2) < std::distance(mm->begin(), transpose));
14161416
}
14171417

1418+
TEST_CASE(hoist_external_inputs_no_movement_needed)
1419+
{
1420+
// Test where external dependencies are already before the fusion chain
1421+
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
1422+
migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
1423+
migraphx::shape s3{migraphx::shape::float_type, {2, 4}};
1424+
migraphx::program p;
1425+
auto* mm = p.get_main_module();
1426+
1427+
auto a = mm->add_parameter("a", s1);
1428+
auto b = mm->add_parameter("b", s2);
1429+
auto c = mm->add_parameter("c", s3);
1430+
1431+
// External operations already positioned before the fusion chain
1432+
auto external1 = add_pointwise(p, "main:pointwise0", {a}, single_pointwise("relu"));
1433+
auto external2 = add_pointwise(p, "main:pointwise1", {b}, single_pointwise("tanh"));
1434+
auto external3 = add_pointwise(p, "main:pointwise2", {c}, single_pointwise("tanh"));
1435+
1436+
// Fusion chain
1437+
auto dot1 = mm->add_instruction(migraphx::make_op("dot"), external1, external2);
1438+
auto add = add_pointwise(p, "main:pointwise3", {dot1, external3}, single_pointwise("add"));
1439+
1440+
mm->add_return({add});
1441+
1442+
// Record positions before hoist_external_inputs
1443+
auto dot1_pos_before = std::distance(mm->begin(), dot1);
1444+
auto add_pos_before = std::distance(mm->begin(), add);
1445+
1446+
mm->hoist_external_inputs(dot1, add);
1447+
1448+
// Verify positions haven't changed (nothing needed to move)
1449+
EXPECT(std::distance(mm->begin(), dot1) == dot1_pos_before);
1450+
EXPECT(std::distance(mm->begin(), add) == add_pos_before);
1451+
EXPECT(is_sorted(*mm));
1452+
}
1453+
1454+
TEST_CASE(hoist_external_inputs_multiple_external_branches)
1455+
{
1456+
// Test with multiple independent external branches that need to be moved
1457+
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
1458+
migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
1459+
migraphx::shape s3{migraphx::shape::float_type, {4, 3}};
1460+
migraphx::shape s4{migraphx::shape::float_type, {2, 4}};
1461+
migraphx::program p;
1462+
auto* mm = p.get_main_module();
1463+
1464+
auto a = mm->add_parameter("a", s1);
1465+
auto b = mm->add_parameter("b", s2);
1466+
auto c = mm->add_parameter("c", s4);
1467+
auto d = mm->add_parameter("d", s3);
1468+
1469+
// Start of fusion chain
1470+
auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4}
1471+
1472+
// External branch 1
1473+
auto ext1 = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {2, 4}
1474+
auto ext2 = add_pointwise(p, "main:pointwise1", {ext1}, single_pointwise("neg")); // {2, 4}
1475+
1476+
// External branch 2
1477+
auto ext3 = add_pointwise(p, "main:pointwise2", {d}, single_pointwise("tanh")); // {4, 3}
1478+
auto ext4 = add_pointwise(p, "main:pointwise3", {ext3}, single_pointwise("abs")); // {4, 3}
1479+
1480+
// Continue fusion chain using external branches
1481+
auto add1 = add_pointwise(p, "main:pointwise4", {dot1, ext2}, single_pointwise("add")); // {2, 4}
1482+
auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add1, ext4); // {2, 4} x {4, 3} = {2, 3}
1483+
1484+
mm->add_return({dot2});
1485+
1486+
mm->hoist_external_inputs(dot1, dot2);
1487+
1488+
EXPECT(is_sorted(*mm));
1489+
1490+
// Verify all external operations moved before dot1
1491+
EXPECT(std::distance(mm->begin(), ext1) < std::distance(mm->begin(), dot1));
1492+
EXPECT(std::distance(mm->begin(), ext2) < std::distance(mm->begin(), dot1));
1493+
EXPECT(std::distance(mm->begin(), ext3) < std::distance(mm->begin(), dot1));
1494+
EXPECT(std::distance(mm->begin(), ext4) < std::distance(mm->begin(), dot1));
1495+
1496+
// Verify fusion chain order preserved
1497+
EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), add1));
1498+
EXPECT(std::distance(mm->begin(), add1) < std::distance(mm->begin(), dot2));
1499+
}
1500+
1501+
TEST_CASE(hoist_external_inputs_long_fusion_chain)
1502+
{
1503+
// Test with a longer fusion chain
1504+
migraphx::shape s{migraphx::shape::float_type, {4, 4}};
1505+
migraphx::program p;
1506+
auto* mm = p.get_main_module();
1507+
1508+
auto a = mm->add_parameter("a", s);
1509+
auto b = mm->add_parameter("b", s);
1510+
auto c = mm->add_parameter("c", s);
1511+
auto d = mm->add_parameter("d", s);
1512+
1513+
// Long fusion chain
1514+
auto op1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
1515+
1516+
// External operations interspersed
1517+
auto ext1 = add_pointwise(p, "main:pointwise0", {d}, single_pointwise("exp"));
1518+
1519+
auto op2 = add_pointwise(p, "main:pointwise1", {op1, c}, single_pointwise("add"));
1520+
1521+
auto ext2 = add_pointwise(p, "main:pointwise2", {ext1}, single_pointwise("log"));
1522+
1523+
auto op3 = add_pointwise(p, "main:pointwise3", {op2}, single_pointwise("relu"));
1524+
1525+
auto ext3 = add_pointwise(p, "main:pointwise4", {ext2}, single_pointwise("abs"));
1526+
1527+
auto op4 = add_pointwise(p, "main:pointwise5", {op3}, single_pointwise("tanh"));
1528+
auto op5 = mm->add_instruction(migraphx::make_op("dot"), op4, ext3);
1529+
1530+
mm->add_return({op5});
1531+
1532+
mm->hoist_external_inputs(op1, op5);
1533+
1534+
EXPECT(is_sorted(*mm));
1535+
1536+
// All external operations moved before op1
1537+
EXPECT(std::distance(mm->begin(), ext1) < std::distance(mm->begin(), op1));
1538+
EXPECT(std::distance(mm->begin(), ext2) < std::distance(mm->begin(), op1));
1539+
EXPECT(std::distance(mm->begin(), ext3) < std::distance(mm->begin(), op1));
1540+
1541+
// Fusion chain order preserved
1542+
EXPECT(std::distance(mm->begin(), op1) < std::distance(mm->begin(), op2));
1543+
EXPECT(std::distance(mm->begin(), op2) < std::distance(mm->begin(), op3));
1544+
EXPECT(std::distance(mm->begin(), op3) < std::distance(mm->begin(), op4));
1545+
EXPECT(std::distance(mm->begin(), op4) < std::distance(mm->begin(), op5));
1546+
}
1547+
1548+
TEST_CASE(hoist_external_inputs_adjacent_instructions)
1549+
{
1550+
// Test edge case where start and end are adjacent
1551+
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
1552+
migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
1553+
migraphx::shape s3{migraphx::shape::float_type, {4, 5}};
1554+
migraphx::program p;
1555+
auto* mm = p.get_main_module();
1556+
1557+
auto a = mm->add_parameter("a", s1); // {2, 3}
1558+
auto b = mm->add_parameter("b", s2); // {3, 4}
1559+
auto c = mm->add_parameter("c", s3); // {4, 5}
1560+
1561+
auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 3} x {3, 4} = {2, 4}
1562+
1563+
// External operation between dot1 and dot2
1564+
auto external = add_pointwise(p, "main:pointwise0", {c}, single_pointwise("relu")); // {4, 5}
1565+
1566+
auto dot2 = mm->add_instruction(migraphx::make_op("dot"), dot1, external); // {2, 4} x {4, 5} = {2, 5}
1567+
1568+
mm->add_return({dot2});
1569+
1570+
mm->hoist_external_inputs(dot1, dot2);
1571+
1572+
EXPECT(is_sorted(*mm));
1573+
1574+
// External should be moved before dot1
1575+
EXPECT(std::distance(mm->begin(), external) < std::distance(mm->begin(), dot1));
1576+
EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), dot2));
1577+
}
1578+
14181579
int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 commit comments

Comments
 (0)