diff --git a/shardy/dialect/sdy/transforms/propagation/BUILD b/shardy/dialect/sdy/transforms/propagation/BUILD index c201af3c..c91099fd 100644 --- a/shardy/dialect/sdy/transforms/propagation/BUILD +++ b/shardy/dialect/sdy/transforms/propagation/BUILD @@ -250,9 +250,10 @@ cc_test( srcs = ["aggressive_factor_propagation_test.cc"], deps = [ ":aggressive_factor_propagation", - ":basic_factor_propagation", + ":factor_propagation", ":sharding_projection", ":testing_utils", + ":utils", "//shardy/dialect/sdy/ir:dialect", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc index 72dd5e5c..c120e4f0 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc @@ -23,38 +23,65 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h" #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h" #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" namespace mlir { namespace sdy { -AxesPerFactor -AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors( - const ShardingProjection& projection, PropagationDirection direction, +namespace { + +bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex, + int64_t factorIndex, ArrayRef newAxes) { + if (tensorIndex < projection.getNumOperands()) { + return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes); + } + return projection.updateResultSharding( + tensorIndex - projection.getNumOperands(), factorIndex, newAxes); +} + +} // namespace + +UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings( + ShardingProjection& projection, PropagationDirection direction, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const { + UpdateTensorShardings result{ + .updateOperands = BitVector(projection.getNumOperands()), + .updateResults = BitVector(projection.getNumResults())}; if (direction == PropagationDirection::NONE) { - return AxesPerFactor(factorSizes.size()); + return result; } - // Finds the compatible major axes ignoring conflicts. - AxesPerFactor result; - result.reserve(factorSizes.size()); + // Find the compatible major axes ignoring conflicts. + SmallVector> axesPerFactor; + axesPerFactor.reserve(factorSizes.size()); + bool allElementsAreEmpty = true; for (int64_t i = 0; i < factorSizes.size(); ++i) { - result.push_back(getCompatibleMajorAxes(projection, i, direction, op)); + SmallVector& axes = axesPerFactor.emplace_back( + getCompatibleMajorAxes(projection, i, direction, op)); + if (!axes.empty()) { + allElementsAreEmpty = false; + } + } + if (allElementsAreEmpty) { + return result; } - // Removes the conflicts within every single factor. This strategy and - // `BasicFactorPropagation` handles conflicts within a factor in the same way. - for (const TensorFactorShardings& tensorFactorShardings : - llvm::concat(projection.getOperands(), - projection.getResults())) { - for (const auto& [factorIndex, factorSharding] : - tensorFactorShardings.factorIndexToSharding) { + // The propagation on each tensor is independent. 
This strategy can propagate + // different shardings to different tensors along the same factor. Examples + // are provided in the docstring of this class. + for (const auto& [tensorIndex, tensorFactorShardings] : + llvm::enumerate(llvm::concat( + projection.getOperands(), projection.getResults()))) { + // Propagate the axes got in Step 1, and resolve conflicts within a factor. + FactorIndexToSharding newSharding = + tensorFactorShardings.factorIndexToSharding; + BitVector factorUpdated(factorSizes.size()); + for (auto& [factorIndex, factorSharding] : newSharding) { + SmallVector newAxes = axesPerFactor[factorIndex]; truncateAxesByRemovingConflicts( - result[factorIndex], + newAxes, [&, factorIndex = factorIndex, &factorSharding = factorSharding]( AxisRefAttr axisRef, int64_t shardedSize) { return compatiblePrefixNoConflictsWithinFactor( @@ -62,107 +89,33 @@ AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors( shardedSize, factorSizes[factorIndex]); }, mesh, conservativePropagation); + if (shouldUpdate(factorSharding.axisRefs, newAxes)) { + factorSharding.axisRefs = newAxes; + factorUpdated.set(factorIndex); + } } - } - // Removes the conflicts across factors, where this strategy and - // `BasicFactorPropagation` diverge. - // - // With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot - // overlap with the existing sharding axes or the overflow axes related to all - // other factors. This criterion is considered for all tensors, no matter if - // Fi is mapped to the tensor or not. The table below shows the criterion: - // - // existing sharding axes & overflow axes new sharding axes - // factor in tensor remove overlap - - // factor not in tensor remove overlap - - // - // On the contrary, `AggressiveFactorPropagation` has the following criterion: - // - // existing sharding axes & overflow axes new sharding axes - // factor in tensor remove overlap remove overlap - // factor not in tensor - - - // - // There are two differences: - // - // 1. `BasicFactorPropagation` removes the overlap between the compatible axes - // of a factor Fi with the existing sharding axes and overflow axes in a - // tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not - // remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too - // strict, since we cannot propagate sharding axes to Tj along Fi. - // - // `AggressiveFactorPropagation` cannot handle the following case if we only - // have difference #1. `-` means that the factor is not mapped to the tensor. - // After removing conflicts within factors, we will propagate "x" to T2 along - // F0 and F1 at the same time, which induces a conflict. To resolve this - // conflict, we have difference #2. - // - // F0 F1 - // T0 "x" - - // T1 - "x" - // T2 ? ? - // - // 2. `AggressiveFactorPropagation` removes the overlap between compatible - // axes of a factor Fi with the potential new sharding axes of other factors - // in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi - // without conflicts with other factors. In the example, we will not propagate - // "x" along F0 or F1 since their potential new sharding axes overlap. - // - // The potential new sharding axes are saved in `resultSnapshot`. It is a hard - // copy since we need to handle the following case. - // - // F0 F1 F2 - // T0 "x" - - - // T1 - "x" - - // T2 - - "x" - // T3 ? ? ? - // - // The `result` and `resultSnapshot` is [["x"], ["x"], ["x"]] before removing - // conflicts across factors. 
After removing conflicts between F0/F1 and other - // factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2 - // and other factors, if we use `result` as the potential new sharding axes, - // we will not remove "x" for F2 because it is no longer present in 'result' - // for F0 and F1. We have to use `resultSnapshot` to save the potential new - // sharding axes and remove "x" for F2. - const AxesPerFactor resultSnapshot = result; - for (const TensorFactorShardings& tensorFactorSharding : - llvm::concat(projection.getOperands(), - projection.getResults())) { - for (const auto& [factorIndex, factorSharding] : - tensorFactorSharding.factorIndexToSharding) { + // Resolve conflicts (overlapping sharding axes) between factors. + bool tensorUpdated = false; + for (const int64_t factorIndex : factorUpdated.set_bits()) { + SmallVector newAxes = newSharding[factorIndex].axisRefs; truncateAxesByRemovingConflicts( - result[factorIndex], + newAxes, [&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) { return compatiblePrefixNoConflictsAcrossFactors( - axisRef, tensorFactorSharding.factorIndexToSharding, - factorIndex, resultSnapshot); + axisRef, newSharding, factorIndex); }, mesh, conservativePropagation); + tensorUpdated |= + updateTensorSharding(projection, tensorIndex, factorIndex, newAxes); } - } - return result; -} - -UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - ArrayRef factorSizes, MeshAttr mesh, Operation* op, - bool conservativePropagation) const { - UpdateTensorShardings result{ - .updateOperands = BitVector(projection.getNumOperands()), - .updateResults = BitVector(projection.getNumResults())}; - - // We get the compatible major sharding axes for all factors. - AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors( - projection, direction, factorSizes, mesh, op, conservativePropagation); - - for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) { - // Update all shardings along this factor if possible. - auto [updateOperandForFactor, updateResultForFactor] = - projection.updateSharding(factorIndex, axesToPropagate); - - result.updateOperands |= updateOperandForFactor; - result.updateResults |= updateResultForFactor; + if (tensorIndex < projection.getNumOperands()) { + result.updateOperands[tensorIndex] = tensorUpdated; + } else { + result.updateResults[tensorIndex - projection.getNumOperands()] = + tensorUpdated; + } } return result; diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h index b87ae6bc..97d3481e 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h @@ -22,40 +22,53 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h" +#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h" #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" namespace mlir { namespace sdy { -// An aggressive strategy of propagating sharding axes along factors. +// An aggressive strategy of propagating sharding axes along factors. There are +// two main differences from `BasicFactorPropagation`. 
// -// This strategy is the same as `BasicFactorPropagation` on the conflicts within -// a factor. They are different on the conflicts across factors. +// `BasicFactorPropagation` propagates the same sharding axes to all tensors +// along a factor. This strategy can propagate different sharding axes to +// different tensors along the same factor. For example, Tensors T0, T1, T2 +// contain Factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is +// already used by T2 ("b" can be explicitly replicated, or it is used to shard +// another factor). `BasicFactorPropagation` propagates ["a"] to both T1/F0 and +// T2/F0, while this strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0, +// respectively. If T2/F0 is closed, `BasicFactorPropagation` propagates +// nothing, while this strategy propagates nothing to T2/F0 and still propagates +// ["a", "b"] to T1/F0. // -// `BasicFactorPropagation` considers the conflicts across factors with a strict -// criterion. The result cannot overlap with the sharded axes or overflow axes -// related to all other factors. This aggressive strategy ignores "fake -// conflicts", which are propagation choices that can co-exist. This aggressive -// strategy ensures that the resultant axes can be propagated to all tensors -// containing the factor. Several examples of fake conflicts: +// `BasicFactorPropagation` is conservative in terms of conflicts across +// factors. The overlapped axis between factors cannot be propagated. This +// strategy is more aggressive by allowing the overlapped axis being propagated +// along different factors if there is no overlapped axis in the result +// shardings. // -// 1. An axis is in factors Fi and Fj. If it is infeasible to propagate that -// axis along factor Fi, we may propagate that axis along factor Fj if all the -// destination tensors have not used that axis. +// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a +// non-contracting dimension of A. F1 corresponds to a non-contracting dimension +// of B. F2 corresponds to a contracting dimension. "-" means that the tensor +// does not contain the factor. // -// 2. Two factors Fi and Fj do not co-exist in any tensor, so they never -// interfere with each other. If Fi and Fj are sharded along the same axis, we -// can propagate that axis along both factors. +// F0 F1 F2 +// A "a" - +// B - +// C "a" - +// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while +// this strategy propagates "a" to B/F1. // -// Although fake conflicts can co-exist without inference, we may still need to -// all-gather some tensors. +// F0 F1 F2 +// A "a" - +// B - "a" +// C - +// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy +// propagate nothing. We can propagate "a" to C/F0 or C/F1, which is illegal +// since "a" cannot be used twice in C. 
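// Editorial sketch (not part of this patch): Case 1 above, written out in the
// `ShardingProjection` notation that the updated tests use, assuming the
// `createAxis` helper from `testing_utils`. Blank table entries correspond to
// open, unsharded factors.
//
//   ShardingProjection case1(
//       /*operands=*/
//       {
//           // A: F0 is sharded on "a", F2 is open.
//           {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
//                                      {2, {.axisRefs = {}}}}},
//           // B: F1 and F2 are open.
//           {.factorIndexToSharding = {{1, {.axisRefs = {}}},
//                                      {2, {.axisRefs = {}}}}},
//       },
//       /*results=*/{
//           // C: F0 is open, F1 is sharded on "a".
//           {.factorIndexToSharding = {{0, {.axisRefs = {}}},
//                                      {1, {.axisRefs = {createAxis("a")}}}}},
//       });
//
// `AggressiveFactorPropagation` would propagate "a" to B along F1 only;
// propagating it to C along F0 is truncated because "a" is already used by
// C/F1.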
class AggressiveFactorPropagation : public BasicFactorPropagation { public: - AxesPerFactor getCompatibleMajorShardingAxesForAllFactors( - const ShardingProjection& projection, PropagationDirection direction, - ArrayRef factorSizes, MeshAttr mesh, Operation* op, - bool conservativePropagation) const override; - UpdateTensorShardings propagateFactorShardings( ShardingProjection& projection, PropagationDirection direction, ArrayRef factorSizes, MeshAttr mesh, Operation* op, diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc index 24f49fc7..25dddc04 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc @@ -17,14 +17,14 @@ limitations under the License. #include #include -#include -#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h" +#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h" #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" #include "shardy/dialect/sdy/transforms/propagation/testing_utils.h" +#include "shardy/dialect/sdy/transforms/propagation/utils.h" #include #include @@ -35,42 +35,37 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; -class GetCompatibleMajorShardingAxesTest : public PropagationTestBase { +class AggressiveFactorPropagationTest : public PropagationTestBase { protected: - AxesPerFactor getCompatibleMajorShardingAxesForAllFactors( - ShardingProjection projection, - const BasicFactorPropagation& factorPropagation, int64_t numFactors) { - AxesPerFactor result = - factorPropagation.getCompatibleMajorShardingAxesForAllFactors( - projection, PropagationDirection::BOTH, - SmallVector(numFactors, 1), /*mesh=*/nullptr, - /*op=*/nullptr, /*conservativePropagation=*/false); - EXPECT_EQ(result.size(), numFactors); - return result; + UpdateTensorShardings propagateFactorShardings( + ShardingProjection& projection, int64_t numFactors, + PropagationDirection direction = PropagationDirection::BOTH, + MeshAttr mesh = nullptr, bool conservativePropagation = false, + Operation* op = nullptr) { + return AggressiveFactorPropagation().propagateFactorShardings( + projection, direction, SmallVector(numFactors, 1), mesh, op, + conservativePropagation); } - bool basicAndAggressiveFactorPropagationSameResult( - ShardingProjection projection, int64_t numFactors, + UpdateTensorShardings propagateFactorShardings( + ShardingProjection& projection, ArrayRef factorSizes, PropagationDirection direction = PropagationDirection::BOTH, - ArrayRef factorSizes = ArrayRef(), - ArrayRef> meshAxes = {}) { - const AxesPerFactor result1 = getCompatibleMajorShardingAxesForAllFactors( - projection, BasicFactorPropagation(), numFactors); - const AxesPerFactor result2 = getCompatibleMajorShardingAxesForAllFactors( - projection, AggressiveFactorPropagation(), numFactors); - return result1 == result2; + MeshAttr mesh = nullptr, Operation* op = nullptr, + bool conservativePropagation = false) { + return AggressiveFactorPropagation().propagateFactorShardings( + projection, direction, factorSizes, mesh, op, conservativePropagation); } }; -TEST_F(GetCompatibleMajorShardingAxesTest, RealAndFakeConflicts) { 
+TEST_F(AggressiveFactorPropagationTest, RealAndFakeConflicts) { ShardingProjection projection( /*operands=*/ { {.factorIndexToSharding = { {0, {.axisRefs = {createAxis("a")}}}, - {2, {.axisRefs = {createAxis("b")}}}, - {4, {.axisRefs = {createAxis("c")}}}, + {2, {.axisRefs = {createAxis("c")}}}, + {4, {.axisRefs = {createAxis("b")}}}, {6, {.axisRefs = {createAxis("e")}}}, {7, {.axisRefs = {}}}, {8, {.axisRefs = {createAxis("f")}}}, @@ -80,9 +75,9 @@ TEST_F(GetCompatibleMajorShardingAxesTest, RealAndFakeConflicts) { { {1, {.axisRefs = {createAxis("a")}}}, {2, {.axisRefs = {}}}, - {3, {.axisRefs = {createAxis("b")}}}, + {3, {.axisRefs = {createAxis("c")}}}, {4, {.axisRefs = {}}}, - {5, {.axisRefs = {createAxis("c")}}}, + {5, {.axisRefs = {createAxis("b")}}}, {6, {.axisRefs = {}}}, {7, {.axisRefs = {createAxis("e")}}}, {9, {.axisRefs = {createAxis("g")}}}, @@ -94,9 +89,9 @@ TEST_F(GetCompatibleMajorShardingAxesTest, RealAndFakeConflicts) { { {0, {.axisRefs = {}}}, {1, {.axisRefs = {}}}, - {2, {.axisRefs = {}, .isClosed = true}}, + {2, {.axisRefs = {}, .overflowAxes = {createAxis("d")}}}, {3, {.axisRefs = {}}}, - {4, {.axisRefs = {}, .overflowAxes = {createAxis("d")}}}, + {4, {.axisRefs = {}, .isClosed = true}}, {5, {.axisRefs = {}}}, {6, {.axisRefs = {}, .isClosed = true}}, {7, {.axisRefs = {}}}, @@ -104,94 +99,68 @@ TEST_F(GetCompatibleMajorShardingAxesTest, RealAndFakeConflicts) { {9, {.axisRefs = {}}}, }}, }); + ShardingProjection projectionExpected( + /*operands=*/{projection.getOperand(0), projection.getOperand(1)}, + /*results=*/{ + {.factorIndexToSharding = + { + {0, {.axisRefs = {}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {}, .overflowAxes = {createAxis("d")}}}, + {3, {.axisRefs = {createAxis("c")}}}, + {4, {.axisRefs = {}, .isClosed = true}}, + {5, {.axisRefs = {createAxis("b")}}}, + {6, {.axisRefs = {}, .isClosed = true}}, + {7, {.axisRefs = {createAxis("e")}}}, + {8, {.axisRefs = {createAxis("f")}}}, + {9, {.axisRefs = {createAxis("g")}}}, + }}, + }); - // Basic strategy does not propagate anything for this case. - AxesPerFactor resultWithBasicStrategy = - getCompatibleMajorShardingAxesForAllFactors(projection, - BasicFactorPropagation(), 11); - for (ArrayRef element : resultWithBasicStrategy) { - EXPECT_THAT(element, IsEmpty()); - } - - AxesPerFactor resultWithAggressiveStrategy = - getCompatibleMajorShardingAxesForAllFactors( - projection, AggressiveFactorPropagation(), 11); - - // Axis "a" is in factors 0 and 1, which co-exists in the result. We can - // propagate "a" along factor 0 or 1 to the result. The real conflicts - // prohibit further propagation along these two factors. - EXPECT_THAT(resultWithAggressiveStrategy[0], IsEmpty()); - EXPECT_THAT(resultWithAggressiveStrategy[1], IsEmpty()); - - // Axis "b" is in factors 2 and 3. Since we cannot propagate "b" along - // factor 2 (the factor sharding is closed in the result), we can propagate - // "b" along factor 3. - EXPECT_THAT(resultWithAggressiveStrategy[2], IsEmpty()); - EXPECT_THAT(resultWithAggressiveStrategy[3], ElementsAre(AxisRefIs("b"))); - - // Axis "c" is in factors 4 and 5. Since we cannot propagate "c" along - // factor 4 (the factor sharding has overflow axes in the result), we can - // propagate "c" along factor 5. - EXPECT_THAT(resultWithAggressiveStrategy[4], IsEmpty()); - EXPECT_THAT(resultWithAggressiveStrategy[5], ElementsAre(AxisRefIs("c"))); - - // Axis "e" is in factors 6 and 7. We cannot propagate "e" along factor 6 - // since the factor sharding is closed in the result. 
We cannot propagate - // "e" along factor 7 since operand 0 contains factor 7 and is already - // sharded along factor 7. - EXPECT_THAT(resultWithAggressiveStrategy[6], IsEmpty()); - EXPECT_THAT(resultWithAggressiveStrategy[7], IsEmpty()); - - // Factor 10 already contains axes "f" and "g". Factor does not appear in - // the result. Hence, we can propagate "f" and "g" to result along factor 8 - // and 9, respectively. - EXPECT_THAT(resultWithAggressiveStrategy[8], ElementsAre(AxisRefIs("f"))); - EXPECT_THAT(resultWithAggressiveStrategy[9], ElementsAre(AxisRefIs("g"))); - EXPECT_THAT(resultWithAggressiveStrategy[10], IsEmpty()); + // Axis "a" may be propagated to the result along factors 0 or 1, which forms + // a real conflict. Thus, we do not apply either of propagation choices. + // + // Other conflicts are fake. We can propagate other axes as much as possible. + // Axes "c", "b", "e", "f", "g" can be propagated to the result along factors + // 3, 5, 7, 8, 9, respectively, since these axes cannot propagated to the + // result along other factors. Also the sharding after propagation is valid + // (sharding axes are not overlapped with each other). + // + // Propagation on different factors are independent. Although we cannot + // propagate "e" to the Operand 0 along factor 7, we still propagate "e" to + // the result along factor 7. + auto [updateOperands, updateResults] = + propagateFactorShardings(projection, 11); + EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); + EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0)); + EXPECT_EQ(projection, projectionExpected); } -TEST_F(GetCompatibleMajorShardingAxesTest, TwoFactorsDoNotCoExistInAnyTensor) { +TEST_F(AggressiveFactorPropagationTest, TwoFactorsDoNotCoExistInAnyTensor) { ShardingProjection projection( /*operands=*/ { - {.factorIndexToSharding = - { - {0, {.axisRefs = {createAxis("a")}}}, - }}, - {.factorIndexToSharding = - { - {1, {.axisRefs = {createAxis("a")}}}, - }}, + {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}}}, + {.factorIndexToSharding = {{1, {.axisRefs = {createAxis("a")}}}}}, }, /*results=*/{ - {.factorIndexToSharding = - { - {0, {.axisRefs = {}}}, - }}, - {.factorIndexToSharding = - { - {1, {.axisRefs = {}}}, - }}, + {.factorIndexToSharding = {{0, {.axisRefs = {}}}}}, + {.factorIndexToSharding = {{1, {.axisRefs = {}}}}}, }); + ShardingProjection projectionExpected( + /*operands=*/{projection.getOperand(0), projection.getOperand(1)}, + /*results=*/{projection.getOperand(0), projection.getOperand(1)}); - // Basic strategy does not propagate anything for this case. - AxesPerFactor resultWithBasicStrategy = - getCompatibleMajorShardingAxesForAllFactors(projection, - BasicFactorPropagation(), 2); - for (ArrayRef element : resultWithBasicStrategy) { - EXPECT_THAT(element, IsEmpty()); - } - - // Factors 0 and 1 do not co-exist in any tensor. Hence, we can propagate - // axis "a" along both factors. - AxesPerFactor resultWithAggressiveStrategy = - getCompatibleMajorShardingAxesForAllFactors( - projection, AggressiveFactorPropagation(), 2); - EXPECT_THAT(resultWithAggressiveStrategy[0], ElementsAre(AxisRefIs("a"))); - EXPECT_THAT(resultWithAggressiveStrategy[1], ElementsAre(AxisRefIs("a"))); + // We can propagate axis "a" along both factors since the two factors do not + // co-exist in any tensor. 
+ auto [updateOperands, updateResults] = + propagateFactorShardings(projection, 2); + EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); + EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 1)); + EXPECT_EQ(projection, projectionExpected); } -TEST_F(GetCompatibleMajorShardingAxesTest, +TEST_F(AggressiveFactorPropagationTest, ConflictsBetweenDifferentFactorsAndReplicated) { ShardingProjection projection( /*operands=*/ @@ -242,124 +211,149 @@ TEST_F(GetCompatibleMajorShardingAxesTest, }}, }); - int64_t numFactors = 6; - AxesPerFactor resultWithBasicStrategy = - getCompatibleMajorShardingAxesForAllFactors( - projection, BasicFactorPropagation(), numFactors); - AxesPerFactor resultWithAggressiveStrategy = - getCompatibleMajorShardingAxesForAllFactors( - projection, AggressiveFactorPropagation(), numFactors); - - // The compatible major axes for factor 0 are ["a", "b", "c"], but the 2nd - // operand, which isn't mapped to factor 0, has a different factor (1) that - // is sharded along "b". - EXPECT_THAT(resultWithBasicStrategy[0], ElementsAre(AxisRefIs("a"))); - // Since factors 0 and 1 does not co-exist in the same operand, we can - // ignore their conflicts. - EXPECT_THAT(resultWithAggressiveStrategy[0], - ElementsAre(AxisRefIs("a"), AxisRefIs("b"), AxisRefIs("c"))); - - // Axis "b" appears in factors 0 and 1, which does not co-exist in the same - // operand. Hence, we can propagate axis "b" along factors 0 and 1, - // respectively. `getAxesWithConservativeStrategy` treat it as a conflict, - // while `getAxesWithAggressiveStrategy` ignore - // this fake conflict. - EXPECT_THAT(resultWithBasicStrategy[1], IsEmpty()); - EXPECT_THAT(resultWithAggressiveStrategy[1], ElementsAre(AxisRefIs("b"))); + ShardingProjection projectionExpected( + /*operands=*/ + { + {.factorIndexToSharding = + { + {0, + {.axisRefs = {createAxis("a"), createAxis("b"), + createAxis("c")}}}, + {3, + {.axisRefs = {createSubAxis("h", 1, 2), createAxis("i")}}}, + }}, + projection.getOperand(1), + {.factorIndexToSharding = + { + {0, + {.axisRefs = {createAxis("a"), createAxis("b"), + createAxis("c")}}}, + {4, + {.axisRefs = {createSubAxis("j", 1, 8), + createSubAxis("k", 2, 4)}}}, + }}, + }, + /*results=*/{ + {.factorIndexToSharding = + { + {0, + {.axisRefs = {createAxis("a"), createAxis("b"), + createAxis("c"), createAxis("d")}}}, + {3, + {.axisRefs = {createSubAxis("h", 1, 2), + createSubAxis("i", 1, 2)}}}, + }, + .replicatedAxes = {createSubAxis("h", 2, 4), + createSubAxis("i", 2, 2)}}, + projection.getResult(1), + {.factorIndexToSharding = + { + {4, {.axisRefs = {createSubAxis("j", 1, 8)}}}, + {5, {.axisRefs = {createSubAxis("k", 1, 4)}}}, + }}, + }); - // For other factors, the results are the same. 
- for (int64_t i = 2; i < numFactors; i++) { - EXPECT_TRUE(resultWithBasicStrategy[i] == resultWithAggressiveStrategy[i]); - } + auto [updateOperands, updateResults] = + propagateFactorShardings(projection, 6); + EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(0, 2)); + EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 2)); + EXPECT_EQ(projection, projectionExpected); } -TEST_F(GetCompatibleMajorShardingAxesTest, FullAxesConflictsOnlyForSameFactor) { +TEST_F(AggressiveFactorPropagationTest, NewAxesConflict) { ShardingProjection projection( /*operands=*/ { {.factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a")}}}, - {3, {.axisRefs = {createAxis("h")}}}, - }}, + {0, + {.axisRefs = {createAxis("a"), createAxis("b"), + createAxis("c")}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {}, .isClosed = true}}, + {3, {.axisRefs = {}}}, + }, + .replicatedAxes = {createAxis("d")}}, {.factorIndexToSharding = { {0, {.axisRefs = {}}}, - {2, {.axisRefs = {createAxis("g")}}}, - {3, {.axisRefs = {createAxis("i")}}}, + {1, {.axisRefs = {createAxis("b"), createAxis("a")}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {createAxis("d")}}}, + }}, + {.factorIndexToSharding = + { + {0, {.axisRefs = {}, .isClosed = true}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {createAxis("c"), createAxis("a")}}}, + {3, {.axisRefs = {}}}, }}, - {.factorIndexToSharding = {{1, {.axisRefs = {createAxis("e")}}}}}, }, /*results=*/{ {.factorIndexToSharding = { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), - createAxis("c")}}}, - {1, {.axisRefs = {createAxis("e"), createAxis("f")}}}, + {0, {.axisRefs = {}}}, + {1, {.axisRefs = {}, .isClosed = true}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, }}, {.factorIndexToSharding = { {0, {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("d")}}}, + {1, {.axisRefs = {}}}, {2, {.axisRefs = {}}}, - {3, {.axisRefs = {createAxis("i"), createAxis("j")}}}, + {3, {.axisRefs = {}}}, }}, }); - // Two strategies have the same criterion on the conflicts within a single - // factor. 
- EXPECT_TRUE(basicAndAggressiveFactorPropagationSameResult(projection, 4)); -} -TEST_F(GetCompatibleMajorShardingAxesTest, SubAxesConflictsOnlyForSameFactor) { - ShardingProjection projection( + ShardingProjection projectionExpected( /*operands=*/ { + projection.getOperand(0), {.factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, - {1, - {.axisRefs = {createSubAxis("c", 2, 2), createAxis("d")}}}, - {4, {.axisRefs = {createSubAxis("i", 1, 4)}}}, - {5, {.axisRefs = {createSubAxis("k", 1, 8)}}}, + {0, {.axisRefs = {}}}, + {1, {.axisRefs = {createAxis("b"), createAxis("a")}}}, + {2, {.axisRefs = {createAxis("c")}}}, + {3, {.axisRefs = {createAxis("d")}}}, }}, {.factorIndexToSharding = { - {0, {.axisRefs = {}}}, - {1, {.axisRefs = {createSubAxis("c", 2, 4)}}}, - {3, - {.axisRefs = {createAxis("g"), createSubAxis("h", 1, 8)}}}, - {5, - {.axisRefs = {createSubAxis("k", 1, 4)}, .isClosed = true}}, + {0, {.axisRefs = {}, .isClosed = true}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {createAxis("c"), createAxis("a")}}}, + {3, {.axisRefs = {createAxis("d")}}}, }}, }, /*results=*/{ {.factorIndexToSharding = { - {0, {.axisRefs = {createSubAxis("a", 1, 2)}}}, - {2, - {.axisRefs = {createSubAxis("e", 2, 4), - createSubAxis("f", 4, 2)}}}, - {3, - {.axisRefs = {createAxis("g"), createSubAxis("h", 1, 4), - createAxis("i")}}}, - {5, {.axisRefs = {createSubAxis("k", 1, 2)}}}, + {0, {.axisRefs = {}}}, + {1, {.axisRefs = {}, .isClosed = true}}, + {2, {.axisRefs = {createAxis("c")}}}, + {3, {.axisRefs = {createAxis("d")}}}, }}, {.factorIndexToSharding = { - {1, - {.axisRefs = {createSubAxis("c", 2, 4), createAxis("d")}}}, - {2, - {.axisRefs = {createSubAxis("e", 2, 4), - createSubAxis("f", 4, 8)}}}, - {4, {.axisRefs = {createSubAxis("j", 2, 4)}}}, - {5, {.axisRefs = {}}}, + {0, + {.axisRefs = {createAxis("a"), createAxis("b"), + createAxis("d")}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {createAxis("c")}}}, + {3, {.axisRefs = {}}}, }}, }); - // Two strategies have the same criterion on the conflicts within a single - // factor. - EXPECT_TRUE(basicAndAggressiveFactorPropagationSameResult(projection, 6)); + + // “a” can be propagated to the Result 0 along either Factor 0 or Factor 2. + // This strategy truncate “a” for both F0 and F2 in Result 0. Namely, this + // strategy does not resolve real conflicts across factors. 
+ auto [updateOperands, updateResults] = + propagateFactorShardings(projection, 4); + EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(1, 2)); + EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 1)); + EXPECT_EQ(projection, projectionExpected); } } // namespace diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc index 78954d6d..975267e4 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc @@ -59,7 +59,7 @@ std::optional getPrefixWithoutOverlap( std::optional BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors( AxisRefAttr axisRef, const FactorIndexToSharding& factorIndexToSharding, - int64_t factorIndex, AxesPerFactorRef newAxesPerFactor) const { + int64_t factorIndex) const { AxisRefAttr result = axisRef; for (const auto& [otherFactorIndex, shardings] : factorIndexToSharding) { if (otherFactorIndex != factorIndex) { @@ -67,11 +67,6 @@ BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors( result, getPrefixWithoutOverlap(result, shardings.overflowAxes)); ASSIGN_OR_RETURN_IF_NULLOPT( result, getPrefixWithoutOverlap(result, shardings.axisRefs)); - if (!newAxesPerFactor.empty()) { - ASSIGN_OR_RETURN_IF_NULLOPT( - result, getPrefixWithoutOverlap( - result, newAxesPerFactor[otherFactorIndex])); - } } } return result; @@ -389,21 +384,6 @@ SmallVector BasicFactorPropagation::getCompatibleMajorShardingAxes( return resultAxes; } -AxesPerFactor -BasicFactorPropagation::getCompatibleMajorShardingAxesForAllFactors( - const ShardingProjection& projection, PropagationDirection direction, - ArrayRef factorSizes, MeshAttr mesh, Operation* op, - bool conservativePropagation) const { - AxesPerFactor result; - result.reserve(factorSizes.size()); - for (auto [factorIndex, factorSize] : llvm::enumerate(factorSizes)) { - result.push_back(getCompatibleMajorShardingAxes( - projection, factorIndex, direction, factorSize, mesh, op, - conservativePropagation)); - } - return result; -} - UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings( ShardingProjection& projection, PropagationDirection direction, ArrayRef factorSizes, MeshAttr mesh, Operation* op, diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h index 853a4b33..70618791 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h @@ -29,9 +29,6 @@ limitations under the License. namespace mlir { namespace sdy { -using AxesPerFactor = SmallVector>; -using AxesPerFactorRef = ArrayRef>; - // A conservative strategy of propagating sharding axes along the factor. // // Refer to the documentation of `getCompatibleMajorShardingAxes` for the @@ -81,19 +78,12 @@ class BasicFactorPropagation : public FactorPropagation { // - Given factor shardings ["a":(1)2, "b"] and ["a":(1)4], returns // ["a":(1)2]. // - // TODO(b/350563653). Mark the following two methods as protected or private. - virtual SmallVector getCompatibleMajorShardingAxes( + // TODO(b/350563653). Mark the following method as protected or private. 
+ SmallVector getCompatibleMajorShardingAxes( const ShardingProjection& projection, int64_t factorIndex, PropagationDirection direction, int64_t factorSize, MeshAttr mesh, Operation* op, bool conservativePropagation) const; - // Similar to `getCompatibleMajorShardingAxes`, but returns the compatible - // major axes for all factors. - virtual AxesPerFactor getCompatibleMajorShardingAxesForAllFactors( - const ShardingProjection& projection, PropagationDirection direction, - ArrayRef factorSizes, MeshAttr mesh, Operation* op, - bool conservativePropagation) const; - // Propagates the factor shardings in `projection`. UpdateTensorShardings propagateFactorShardings( ShardingProjection& projection, PropagationDirection direction, @@ -112,20 +102,18 @@ class BasicFactorPropagation : public FactorPropagation { const ShardingProjection& projection, int64_t factorIndex, PropagationDirection direction, Operation* op) const; - // Returns the largest prefix of `axisRef`, which does not overlap with (1) - // overflow axes, (2) existing sharding axes, and (3) potential new sharding - // axes in `newAxesPerFactor` for all other factors. + // Returns the largest prefix of `axisRef`, which does not overlap with + // sharding axes and overflow axes for all other factors. // // This function does not consider the conflicts within the factor itself, // which are considered in `compatiblePrefixNoConflictsWithinFactor`. The - // returned prefix can be overlapped with (1) overflow axes, (2) existing - // sharding axes, and (3) potential new sharding axes of the factor itself. + // returned prefix can be overlapped with sharding axes and overflow axes of + // the factor itself. // // Returns std::nullopt if the prefix does not exist. std::optional compatiblePrefixNoConflictsAcrossFactors( AxisRefAttr axisRef, const FactorIndexToSharding& factorIndexToSharding, - int64_t factorIndex, - AxesPerFactorRef newAxesPerFactor = AxesPerFactorRef()) const; + int64_t factorIndex) const; // Returns the largest compatible prefix of `axisRef` by removing conflicts // with `replicatedAxes` and `factorSharding`. diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc index 111372e8..0076ab70 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc @@ -15,6 +15,7 @@ limitations under the License. #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" +#include #include #include #include @@ -34,24 +35,31 @@ limitations under the License. namespace mlir { namespace sdy { -namespace { - -// Returns if `oldAxes` should be updated by `newAxes`. 
bool shouldUpdate(ArrayRef oldAxes, ArrayRef newAxes) { if (newAxes.empty()) { return false; } + if (oldAxes.empty()) { + return true; + } + + int64_t minSize = std::min(oldAxes.size(), newAxes.size()); + for (int64_t i = 0; i < minSize - 1; ++i) { + if (oldAxes[i] != newAxes[i]) { + return false; + } + } + if (newAxes.size() < oldAxes.size()) { return false; } if (newAxes.size() > oldAxes.size()) { - return true; + return oldAxes[minSize - 1].prefixOf(newAxes[minSize - 1]); } - return newAxes.back().strictlyContains(oldAxes.back()); -} -} // namespace + return oldAxes.back().strictPrefixOf(newAxes.back()); +} bool TensorFactorShardings::updateShardingAxes(int64_t factorIndex, ArrayRef newAxes) { diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h index 15c7241c..bac0091b 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h @@ -27,6 +27,9 @@ limitations under the License. namespace mlir { namespace sdy { +// Returns true if the `oldAxes` is a strict prefix of `newAxes`, +bool shouldUpdate(ArrayRef oldAxes, ArrayRef newAxes); + // The axes along which a factor is sharded, and whether the factor can be // further sharded (unless it's fully sharded already). struct FactorSharding { @@ -40,6 +43,12 @@ struct FactorSharding { // We need to store these axes so that we can add them when projecting back to // dimension shardings. SmallVector overflowAxes; + + bool operator==(const FactorSharding& other) const { + return axisRefs == other.axisRefs && isClosed == other.isClosed && + isMinorMost == other.isMinorMost && + overflowAxes == other.overflowAxes; + } }; using FactorIndexToSharding = llvm::DenseMap; @@ -51,6 +60,11 @@ struct TensorFactorShardings { FactorIndexToSharding factorIndexToSharding; SmallVector replicatedAxes; + bool operator==(const TensorFactorShardings& other) const { + return factorIndexToSharding == other.factorIndexToSharding && + replicatedAxes == other.replicatedAxes; + } + // Updates the sharding axes of the given `factorIndex` to `newAxes` if // 1. this tensor is associated with that factor, and // 2. `newAxes` strictly contains existing axes. For example, ["a", "b"] @@ -119,6 +133,7 @@ class ShardingProjection { int64_t getNumOperands() const { return operands.size(); } int64_t getNumResults() const { return results.size(); } + int64_t getNumTensors() const { return getNumOperands() + getNumResults(); } ArrayRef getOperands() const { return operands; } ArrayRef getResults() const { return results; } @@ -130,6 +145,15 @@ class ShardingProjection { return results[resultNum]; } + bool updateOperandSharding(int64_t operandIndex, int64_t factorIndex, + ArrayRef newAxes) { + return operands[operandIndex].updateShardingAxes(factorIndex, newAxes); + } + bool updateResultSharding(int64_t resultIndex, int64_t factorIndex, + ArrayRef newAxes) { + return results[resultIndex].updateShardingAxes(factorIndex, newAxes); + } + // Updates the shardings of all tensors that are associated with // `factorIndex` to be `newAxes` for that factor. Returns two BitVectors // indicating whether the operands and results have been updated. 
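// Editorial sketch (not part of this patch): how the new per-tensor update
// entry points compose with the flat tensor index used by the aggressive
// strategy, where operands come first and results follow. This mirrors the
// anonymous `updateTensorSharding` helper added in
// aggressive_factor_propagation.cc; the function name below is illustrative.
inline bool updateShardingByFlatTensorIndex(ShardingProjection& projection,
                                            int64_t tensorIndex,
                                            int64_t factorIndex,
                                            ArrayRef<AxisRefAttr> newAxes) {
  assert(tensorIndex >= 0 && tensorIndex < projection.getNumTensors());
  if (tensorIndex < projection.getNumOperands()) {
    // Flat indices [0, getNumOperands()) address operands directly.
    return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
  }
  // Remaining flat indices address results, offset by the operand count.
  return projection.updateResultSharding(
      tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
}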
@@ -149,6 +173,10 @@ class ShardingProjection { OpShardingRuleAttr shardingRule, MeshAttr mesh); + bool operator==(const ShardingProjection& other) const { + return operands == other.operands && results == other.results; + } + private: SmallVector operands; SmallVector results; diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc index c9ba503f..64d60bd6 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc @@ -636,6 +636,44 @@ TEST_F(ShardingProjectionUpdateShardingTest, DotGeneralSimple) { ElementsAre(AxisRefIs("d"), AxisRefIs("f"))))); } +//===----------------------------------------------------------------------===// +// Tests for shouldUpdate +//===----------------------------------------------------------------------===// + +class ShouldUpdateTest : public PropagationTestBase {}; + +TEST_F(ShouldUpdateTest, ShouldUpdateTest) { + // One of the input arguments is empty. + EXPECT_FALSE(shouldUpdate({}, {})); + EXPECT_FALSE(shouldUpdate({createAxis("a")}, {})); + EXPECT_TRUE(shouldUpdate({}, {createAxis("a")})); + + // The two input arguments are the same. + EXPECT_FALSE(shouldUpdate({createAxis("a")}, {createAxis("a")})); + SmallVector axes = {createAxis("a"), createSubAxis("b", 2, 4)}; + EXPECT_FALSE(shouldUpdate(axes, axes)); + + EXPECT_FALSE(shouldUpdate({createAxis("a")}, {createAxis("b")})); + EXPECT_FALSE(shouldUpdate({createAxis("a"), createAxis("b")}, + {createAxis("b"), createAxis("a")})); + EXPECT_FALSE( + shouldUpdate({createAxis("a"), createSubAxis("b", 2, 4)}, + {createAxis("a"), createAxis("b"), createAxis("c")})); + EXPECT_FALSE( + shouldUpdate({createAxis("a"), createAxis("b"), createAxis("c")}, + {createAxis("a"), createAxis("b"), createAxis("d")})); + + auto expectTrue = [&](ArrayRef oldAxes, + ArrayRef newAxes) { + EXPECT_TRUE(shouldUpdate(oldAxes, newAxes)); + EXPECT_FALSE(shouldUpdate(newAxes, oldAxes)); + }; + expectTrue({createAxis("a"), createAxis("b")}, + {createAxis("a"), createAxis("b"), createAxis("c")}); + expectTrue({createAxis("a"), createSubAxis("b", 1, 4)}, + {createAxis("a"), createAxis("b")}); +} + //===----------------------------------------------------------------------===// // Tests for TensorFactorShardings::createTensorShardingAttr // diff --git a/shardy/dialect/sdy/transforms/propagation/test/aggressive_propagation.mlir b/shardy/dialect/sdy/transforms/propagation/test/aggressive_propagation.mlir index c704c91e..22ef4733 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/aggressive_propagation.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/aggressive_propagation.mlir @@ -19,12 +19,26 @@ func.func @no_conflict(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mes return %1 : tensor<8x16xf32> } -// CHECK-LABEL: func @fake_conflict( +// CHECK-LABEL: func @fake_conflict_between_two_non_contracting_dims( +// CHECK-SAME: %arg0: tensor<256x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {?}]>}, +// CHECK-SAME: %arg1: tensor<128x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {?}]>}) +// CHECK-SAME: -> (tensor<256x128xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{?}, {"a", ?}]>}) { +func.func @fake_conflict_between_two_non_contracting_dims(%arg0: tensor<256x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {?}]>}, + %arg1: tensor<128x512xf32>) + -> (tensor<256x128xf32> 
{sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{?}, {"a", ?}]>}) { + // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1 + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{?}, {"a", ?}]>]>} + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{?}, {"a", ?}]>]>} : + (tensor<256x512xf32>, tensor<128x512xf32>) -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +// CHECK-LABEL: func @fake_conflict_between_contracting_and_non_contracting_dims( // CHECK-SAME: %arg0: tensor<256x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {"b", "c"}]>}, // CHECK-SAME: %arg1: tensor<128x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"c", "b"}, {"a"}]>}) // CHECK-SAME: -> (tensor<256x128xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {"c", "b", ?}]>}) { -func.func @fake_conflict(%arg0: tensor<256x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {"b", "c"}]>}, - %arg1: tensor<128x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"c", "b"}, {"a"}]>}) -> tensor<256x128xf32> { +func.func @fake_conflict_between_contracting_and_non_contracting_dims(%arg0: tensor<256x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {"b", "c"}]>}, + %arg1: tensor<128x512xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"c", "b"}, {"a"}]>}) -> tensor<256x128xf32> { // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1 // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2_c_2, [{"a", ?}, {"c", "b", ?}]>]>} %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [1] : @@ -32,14 +46,50 @@ func.func @fake_conflict(%arg0: tensor<256x512xf32> {sdy.sharding = #sdy.shardin return %0 : tensor<256x128xf32> } -// CHECK-LABEL: func @real_conflict( +// CHECK-LABEL: func @fake_conflict_closed_dims( +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {"b"}]>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{}, {"b", "c", ?}]>}) +// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {"b", "c", ?}]>}) { +func.func @fake_conflict_closed_dims(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {"b"}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{}, {"b", "c", ?}]>}) -> tensor<8x8xf32> { + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2_c_2, [{"a", ?}, {"b", "c", ?}]>]>} + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @real_conflict_across_factors( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {?}]>}, // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{?}, {"a"}]>}) // CHECK-SAME: -> tensor<8x8xf32> { -func.func @real_conflict(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {?}]>}, - %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{?}, {"a"}]>}) -> tensor<8x8xf32> { +func.func @real_conflict_across_factors(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}, {?}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{?}, {"a"}]>}) -> tensor<8x8xf32> { + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 + // 
CHECK-NOT: sdy.sharding + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @real_conflict_within_a_factor( +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {}]>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"b", ?}, {}]>}) +// CHECK-SAME: -> tensor<8x8xf32> { +func.func @real_conflict_within_a_factor(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"b", ?}, {}]>}) -> tensor<8x8xf32> { // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 // CHECK-NOT: sdy.sharding %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } + +// CHECK-LABEL: func @real_and_fake_conflicts( +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {?}]>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"b", ?}, {"a", ?}]>}) +// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{?}, {"a", ?}]>}) { +func.func @real_and_fake_conflicts(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}, {?}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"b", ?}, {"a", ?}]>}) -> tensor<8x8xf32> { + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2_c_2, [{?}, {"a", ?}]>]>} + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_priority_propagation.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_priority_propagation.mlir index 5a996a24..b67f131f 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/op_priority_propagation.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/op_priority_propagation.mlir @@ -6,7 +6,7 @@ sdy.mesh @mesh = <"a"=2, "b"=2> // have been propagated first. // CHECK-LABEL: func @element_wise_over_dot_general( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32>) +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) { func.func @element_wise_over_dot_general(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, %arg1: tensor<8x8xf32>) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) { // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} @@ -22,11 +22,11 @@ func.func @element_wise_over_dot_general(%arg0: tensor<8x8xf32> {sdy.sharding = // Same as `element_wise_over_dot_general` but the dot_general is the last op. 
// CHECK-LABEL: func @element_wise_over_dot_general_flipped_op_order( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32>) +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) { func.func @element_wise_over_dot_general_flipped_op_order(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, %arg1: tensor<8x8xf32>) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) { // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {?}]>]>} : tensor<8x8xf32> - // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} : tensor<8x8xf32> // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot_general %[[ADD_1]], %[[ADD_2]], contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} // CHECK-NEXT: return %[[DOT]] : tensor<8x8xf32> %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> @@ -41,7 +41,7 @@ func.func @element_wise_over_dot_general_flipped_op_order(%arg0: tensor<8x8xf32> // first instead. // CHECK-LABEL: func @sharding_constraint_propagated( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32>) +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}) { func.func @sharding_constraint_propagated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} diff --git a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir index 108f7bc8..85bcb552 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir @@ -23,7 +23,7 @@ func.func @split_constants_different_sharding( // in the hierarchy, which is the user-priority propagation. 
 // CHECK-LABEL: func @user_priorities(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>},
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>}) {
diff --git a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline_data_flow_edges.mlir b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline_data_flow_edges.mlir
index fcb0f99f..cee8fead 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline_data_flow_edges.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline_data_flow_edges.mlir
@@ -172,7 +172,7 @@ func.func @case_multiple_results_different_sharding_conflicts(
 // CHECK-SAME: %arg0: tensor<i32>,
 // CHECK-SAME: %arg1: tensor<8xi64> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}]>},
 // CHECK-SAME: %arg2: tensor<8xi64> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", "b"}]>}
-// CHECK-SAME: -> (tensor<8xi64> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", ?}]>}
+// CHECK-SAME: -> (tensor<8xi64> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a", "b", ?}]>}
 func.func @case_closed_sharding(
     %arg0: tensor<i32>,
     %arg1: tensor<8xi64> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2_c_2, [{"a"}]>},
@@ -182,7 +182,7 @@ func.func @case_closed_sharding(
     stablehlo.return %arg1 : tensor<8xi64>
   }, {
     stablehlo.return %arg2 : tensor<8xi64>
-  // CHECK: }) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2_c_2, [{"a", ?}]>]>} :
+  // CHECK: }) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2_c_2, [{"a", "b", ?}]>]>} :
   }) : (tensor<i32>) -> tensor<8xi64>
   return %0 : tensor<8xi64>
 }
diff --git a/shardy/dialect/sdy/transforms/propagation/test/user_priority_propagation.mlir b/shardy/dialect/sdy/transforms/propagation/test/user_priority_propagation.mlir
index ff0221cc..d0e375e9 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/user_priority_propagation.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/user_priority_propagation.mlir
@@ -18,7 +18,7 @@ func.func @no_priorities(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@m
 // CHECK-LABEL: func @skipped_priorities(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>},
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}]>}) {
 func.func @skipped_priorities(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>},
@@ -32,7 +32,7 @@ func.func @skipped_priorities(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi
 // CHECK-LABEL: func @arg_lower_priority_than_return_value(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>},
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>}) {
@@ -50,9 +50,9 @@ func.func @arg_lower_priority_than_return_value(
 // CHECK-LABEL: func @arg_lower_priority_than_return_value_with_replicated(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>},
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
-// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>})
+// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>}) {
 func.func @arg_lower_priority_than_return_value_with_replicated(
     %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1, {"b"}p1]>},
@@ -71,7 +71,7 @@ func.func @arg_lower_priority_than_return_value_with_replicated(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
-// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>})
+// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>}) {
 func.func @arg_higher_priority_than_return_value(
     %arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0, {"b"}p0]>},
@@ -107,7 +107,7 @@ func.func @result_lower_priority_than_arg(
 // CHECK-LABEL: func @result_higher_priority_than_arg(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>},
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) {
@@ -126,7 +126,7 @@ func.func @result_higher_priority_than_arg(
 // CHECK-LABEL: func @dim_with_lower_priority_gets_further_sharded_by_higher(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", "a", ?}, {}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32>,
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>},
 // CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", "a", ?}, {?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>})
 func.func @dim_with_lower_priority_gets_further_sharded_by_higher(
@@ -145,7 +145,7 @@ func.func @dim_with_lower_priority_gets_further_sharded_by_higher(
 // CHECK-LABEL: func @different_priorities_with_closed_empty_dim(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
 // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
-// CHECK-SAME: %arg2: tensor<8x8xf32>,
+// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>}) {
 func.func @different_priorities_with_closed_empty_dim(
@@ -163,7 +163,7 @@ func.func @different_priorities_with_closed_empty_dim(
 // CHECK-LABEL: func @open_empty_dim_with_priority(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
 // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
-// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>},
+// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}]>},
 // CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c"}]>})
 // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}]>}) {
 func.func @open_empty_dim_with_priority(
@@ -199,7 +199,8 @@ func.func @different_priorities_from_args(
 // CHECK-LABEL: func @different_priorities_from_ops(
 // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>},
-// CHECK-SAME: %arg1: tensor<8x8xf32>, %arg2: tensor<8x16xf32>)
+// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>},
+// CHECK-SAME: %arg2: tensor<8x16xf32>)
 // CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}) {
 func.func @different_priorities_from_ops(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> {
@@ -238,7 +239,7 @@ func.func @propagate_to_multi_result_op_with_priorities(
   // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
   // CHECK-NEXT: %[[REDUCE:.*]]:2 = stablehlo.reduce(%arg0 init: %[[CONST]]), (%arg1 init: %[[CONST]]) across dimensions = [1]
   // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>, <@mesh, [{?}, {"a", ?}]>]>}
-  // CHECK: stablehlo.add %[[REDUCE]]#1, %arg2 :
+  // CHECK: stablehlo.add %[[REDUCE]]#1, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} :
   %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
   %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4x8xf32>, tensor<4x8xf32>)
@@ -252,7 +253,8 @@
 }
 // CHECK-LABEL: func @propagate_from_multi_result_op_with_priorities(
-// CHECK-SAME: %arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>)
+// CHECK-SAME: %arg0: tensor<4x64x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {?}, {"b", ?}]>},
+// CHECK-SAME: %arg1: tensor<4x64x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {?}, {"b", ?}]>})
 func.func @propagate_from_multi_result_op_with_priorities(
     %arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> tensor<4x8xf32> {
   // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
@@ -320,7 +322,7 @@ func.func @user_based_and_op_based(
     %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}p1, {?}]>},
     %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}p0, {"b", ?}p0]>})
     -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", ?}p0, {?}]>}) {
-  // CHECK: %[[ADD_0:.*]] = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>} : tensor<8x8xf32>
+  // CHECK: %[[ADD_0:.*]] = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b", ?}, {"a", ?}]>]>} : tensor<8x8xf32>
   // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot_general %[[ADD_0]], %arg1, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b", ?}, {?}]>]>} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32>
   // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %[[DOT]], %[[DOT]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b", ?}, {?}]>]>} : tensor<8x8xf32>
   // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %[[ADD_1]], %[[ADD_1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b", ?}, {?}]>]>} : tensor<8x8xf32>