@@ -1369,7 +1369,7 @@ TEST_CASE(pathological_dfs_graph_sort)
13691369 EXPECT (is_sorted (m));
13701370}
13711371
1372- TEST_CASE (localized_sort )
1372+ TEST_CASE (hoist_external_inputs )
13731373{
13741374 migraphx::shape s1{migraphx::shape::float_type, {2 , 3 }};
13751375 migraphx::shape s2{migraphx::shape::float_type, {3 , 4 }};
@@ -1394,8 +1394,8 @@ TEST_CASE(localized_sort)
13941394 mm->add_instruction (migraphx::make_op (" transpose" , {{" permutation" , {1 , 0 }}}), dot2);
13951395 mm->add_return ({add, transpose});
13961396
1397- // Perform localized sort between dot1 and dot2
1398- mm->localized_sort (dot1, dot2);
1397+ // Hoist external inputs between dot1 and dot2
1398+ mm->hoist_external_inputs (dot1, dot2);
13991399
14001400 // Verify the module is still topologically sorted overall
14011401 EXPECT (is_sorted (*mm));
@@ -1415,4 +1415,165 @@ TEST_CASE(localized_sort)
14151415 EXPECT (std::distance (mm->begin (), dot2) < std::distance (mm->begin (), transpose));
14161416}
14171417
1418+ TEST_CASE (hoist_external_inputs_no_movement_needed)
1419+ {
1420+ // Every input of the dot1 -> add chain is already created before dot1, so the chain's positions are expected to stay unchanged
1421+ migraphx::shape s1{migraphx::shape::float_type, {2 , 3 }};
1422+ migraphx::shape s2{migraphx::shape::float_type, {3 , 4 }};
1423+ migraphx::shape s3{migraphx::shape::float_type, {2 , 4 }};
1424+ migraphx::program p;
1425+ auto * mm = p.get_main_module ();
1426+
1427+ auto a = mm->add_parameter (" a" , s1);
1428+ auto b = mm->add_parameter (" b" , s2);
1429+ auto c = mm->add_parameter (" c" , s3);
1430+
1431+ // External pointwise ops (relu, tanh, tanh) are added ahead of the fusion chain
1432+ auto external1 = add_pointwise (p, " main:pointwise0" , {a}, single_pointwise (" relu" ));
1433+ auto external2 = add_pointwise (p, " main:pointwise1" , {b}, single_pointwise (" tanh" ));
1434+ auto external3 = add_pointwise (p, " main:pointwise2" , {c}, single_pointwise (" tanh" ));
1435+
1436+ // Fusion chain: dot1 feeds the pointwise add, which also reads external3
1437+ auto dot1 = mm->add_instruction (migraphx::make_op (" dot" ), external1, external2);
1438+ auto add = add_pointwise (p, " main:pointwise3" , {dot1, external3}, single_pointwise (" add" ));
1439+
1440+ mm->add_return ({add});
1441+
1442+ // Record the offsets of dot1 and add from mm->begin () before calling hoist_external_inputs
1443+ auto dot1_pos_before = std::distance (mm->begin (), dot1);
1444+ auto add_pos_before = std::distance (mm->begin (), add);
1445+
1446+ mm->hoist_external_inputs (dot1, add);
1447+
1448+ // Verify both offsets are unchanged (nothing needed to move) and the module is still sorted
1449+ EXPECT (std::distance (mm->begin (), dot1) == dot1_pos_before);
1450+ EXPECT (std::distance (mm->begin (), add) == add_pos_before);
1451+ EXPECT (is_sorted (*mm));
1452+ }
1453+
1454+ TEST_CASE (hoist_external_inputs_multiple_external_branches)
1455+ {
1456+ // Two independent external branches (c -> relu -> neg and d -> tanh -> abs) are created after dot1 and must end up before it
1457+ migraphx::shape s1{migraphx::shape::float_type, {2 , 3 }};
1458+ migraphx::shape s2{migraphx::shape::float_type, {3 , 4 }};
1459+ migraphx::shape s3{migraphx::shape::float_type, {4 , 3 }};
1460+ migraphx::shape s4{migraphx::shape::float_type, {2 , 4 }};
1461+ migraphx::program p;
1462+ auto * mm = p.get_main_module ();
1463+
1464+ auto a = mm->add_parameter (" a" , s1);
1465+ auto b = mm->add_parameter (" b" , s2);
1466+ auto c = mm->add_parameter (" c" , s4);
1467+ auto d = mm->add_parameter (" d" , s3);
1468+
1469+ // Start of fusion chain
1470+ auto dot1 = mm->add_instruction (migraphx::make_op (" dot" ), a, b); // {2, 4}
1471+
1472+ // External branch 1: relu then neg on parameter c
1473+ auto ext1 = add_pointwise (p, " main:pointwise0" , {c}, single_pointwise (" relu" )); // {2, 4}
1474+ auto ext2 = add_pointwise (p, " main:pointwise1" , {ext1}, single_pointwise (" neg" )); // {2, 4}
1475+
1476+ // External branch 2: tanh then abs on parameter d
1477+ auto ext3 = add_pointwise (p, " main:pointwise2" , {d}, single_pointwise (" tanh" )); // {4, 3}
1478+ auto ext4 = add_pointwise (p, " main:pointwise3" , {ext3}, single_pointwise (" abs" )); // {4, 3}
1479+
1480+ // Continue fusion chain: add1 consumes ext2 and dot2 consumes ext4
1481+ auto add1 = add_pointwise (p, " main:pointwise4" , {dot1, ext2}, single_pointwise (" add" )); // {2, 4}
1482+ auto dot2 = mm->add_instruction (migraphx::make_op (" dot" ), add1, ext4); // {2, 4} x {4, 3} = {2, 3}
1483+
1484+ mm->add_return ({dot2});
1485+
1486+ mm->hoist_external_inputs (dot1, dot2);
1487+
1488+ EXPECT (is_sorted (*mm));
1489+
1490+ // Verify every external op from both branches now sits before dot1
1491+ EXPECT (std::distance (mm->begin (), ext1) < std::distance (mm->begin (), dot1));
1492+ EXPECT (std::distance (mm->begin (), ext2) < std::distance (mm->begin (), dot1));
1493+ EXPECT (std::distance (mm->begin (), ext3) < std::distance (mm->begin (), dot1));
1494+ EXPECT (std::distance (mm->begin (), ext4) < std::distance (mm->begin (), dot1));
1495+
1496+ // Verify the relative order dot1 -> add1 -> dot2 is preserved
1497+ EXPECT (std::distance (mm->begin (), dot1) < std::distance (mm->begin (), add1));
1498+ EXPECT (std::distance (mm->begin (), add1) < std::distance (mm->begin (), dot2));
1499+ }
1500+
1501+ TEST_CASE (hoist_external_inputs_long_fusion_chain)
1502+ {
1503+ // Five-op chain (op1..op5) interleaved with a three-op external branch (ext1 -> ext2 -> ext3) whose result only op5 consumes
1504+ migraphx::shape s{migraphx::shape::float_type, {4 , 4 }};
1505+ migraphx::program p;
1506+ auto * mm = p.get_main_module ();
1507+
1508+ auto a = mm->add_parameter (" a" , s);
1509+ auto b = mm->add_parameter (" b" , s);
1510+ auto c = mm->add_parameter (" c" , s);
1511+ auto d = mm->add_parameter (" d" , s);
1512+
1513+ // Start of the long fusion chain
1514+ auto op1 = mm->add_instruction (migraphx::make_op (" dot" ), a, b);
1515+
1516+ // External ops (exp -> log -> abs on parameter d) are interleaved with the chain ops below
1517+ auto ext1 = add_pointwise (p, " main:pointwise0" , {d}, single_pointwise (" exp" ));
1518+
1519+ auto op2 = add_pointwise (p, " main:pointwise1" , {op1, c}, single_pointwise (" add" ));
1520+
1521+ auto ext2 = add_pointwise (p, " main:pointwise2" , {ext1}, single_pointwise (" log" ));
1522+
1523+ auto op3 = add_pointwise (p, " main:pointwise3" , {op2}, single_pointwise (" relu" ));
1524+
1525+ auto ext3 = add_pointwise (p, " main:pointwise4" , {ext2}, single_pointwise (" abs" ));
1526+
1527+ auto op4 = add_pointwise (p, " main:pointwise5" , {op3}, single_pointwise (" tanh" ));
1528+ auto op5 = mm->add_instruction (migraphx::make_op (" dot" ), op4, ext3);
1529+
1530+ mm->add_return ({op5});
1531+
1532+ mm->hoist_external_inputs (op1, op5);
1533+
1534+ EXPECT (is_sorted (*mm));
1535+
1536+ // Every external op must now sit before op1
1537+ EXPECT (std::distance (mm->begin (), ext1) < std::distance (mm->begin (), op1));
1538+ EXPECT (std::distance (mm->begin (), ext2) < std::distance (mm->begin (), op1));
1539+ EXPECT (std::distance (mm->begin (), ext3) < std::distance (mm->begin (), op1));
1540+
1541+ // Relative order op1 -> op2 -> op3 -> op4 -> op5 must be preserved
1542+ EXPECT (std::distance (mm->begin (), op1) < std::distance (mm->begin (), op2));
1543+ EXPECT (std::distance (mm->begin (), op2) < std::distance (mm->begin (), op3));
1544+ EXPECT (std::distance (mm->begin (), op3) < std::distance (mm->begin (), op4));
1545+ EXPECT (std::distance (mm->begin (), op4) < std::distance (mm->begin (), op5));
1546+ }
1547+
1548+ TEST_CASE (hoist_external_inputs_adjacent_instructions)
1549+ {
1550+ // Edge case: dot2 consumes dot1 directly, so the chain is only start and end, with one external input created between them
1551+ migraphx::shape s1{migraphx::shape::float_type, {2 , 3 }};
1552+ migraphx::shape s2{migraphx::shape::float_type, {3 , 4 }};
1553+ migraphx::shape s3{migraphx::shape::float_type, {4 , 5 }};
1554+ migraphx::program p;
1555+ auto * mm = p.get_main_module ();
1556+
1557+ auto a = mm->add_parameter (" a" , s1); // {2, 3}
1558+ auto b = mm->add_parameter (" b" , s2); // {3, 4}
1559+ auto c = mm->add_parameter (" c" , s3); // {4, 5}
1560+
1561+ auto dot1 = mm->add_instruction (migraphx::make_op (" dot" ), a, b); // {2, 3} x {3, 4} = {2, 4}
1562+
1563+ // External op created between dot1 and dot2; it is the second input of dot2
1564+ auto external = add_pointwise (p, " main:pointwise0" , {c}, single_pointwise (" relu" )); // {4, 5}
1565+
1566+ auto dot2 = mm->add_instruction (migraphx::make_op (" dot" ), dot1, external); // {2, 4} x {4, 5} = {2, 5}
1567+
1568+ mm->add_return ({dot2});
1569+
1570+ mm->hoist_external_inputs (dot1, dot2);
1571+
1572+ EXPECT (is_sorted (*mm));
1573+
1574+ // external must now sit before dot1, and dot1 must still precede dot2
1575+ EXPECT (std::distance (mm->begin (), external) < std::distance (mm->begin (), dot1));
1576+ EXPECT (std::distance (mm->begin (), dot1) < std::distance (mm->begin (), dot2));
1577+ }
1578+
14181579int main (int argc, const char * argv[]) { test::run (argc, argv); } // Entry point: forwards the command-line arguments to test::run
0 commit comments