
Commit c7c5d33

ZixuanJiang authored and copybara-github committed
Improve the aggressive factor propagation strategy in Shardy. There are two main differences from `BasicFactorPropagation`.

### Difference 1

`BasicFactorPropagation` propagates the same sharding axes to all the tensors along a factor. This strategy can propagate different sharding axes to different tensors. For example, Tensors T0, T1, T2 contain Factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is already used by T2 ("b" can be explicitly replicated, or it is used to shard another factor). `BasicFactorPropagation` can only propagate ["a"] to both T1/F0 and T2/F0, while this strategy can propagate ["a", "b"] to T1/F0 and ["a"] to T2/F0, respectively.

### Difference 2

`BasicFactorPropagation` is conservative in terms of conflicts across factors. If an axis (or sub-axis) appears in two factor shardings, it cannot be propagated. This strategy is more aggressive, allowing the same axis (or sub-axis) to be propagated along different factors as long as the resulting shardings are legal (a tensor can be sharded by an axis at most once).

PiperOrigin-RevId: 655675663
1 parent d6fa764 commit c7c5d33

12 files changed: +314 −376 lines
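To make Difference 1 concrete before the diffs, here is a minimal standalone C++ sketch of the per-tensor truncation idea. It is not Shardy's API: `truncateForTensor`, the string-valued axis names, and the `usedAxes` sets are hypothetical stand-ins for `AxisRefAttr` and the projection's bookkeeping.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical helper: keep the longest prefix of the factor's compatible
// axes that the destination tensor has not already used elsewhere.
std::vector<std::string> truncateForTensor(
    const std::vector<std::string>& compatibleAxes,
    const std::set<std::string>& usedAxes) {
  std::vector<std::string> kept;
  for (const std::string& axis : compatibleAxes) {
    if (usedAxes.count(axis)) break;  // stop at the first conflicting axis
    kept.push_back(axis);
  }
  return kept;
}

int main() {
  // T0/F0 is sharded along ["a", "b"]; T2 already uses "b" elsewhere.
  const std::vector<std::string> compatible = {"a", "b"};
  const std::set<std::string> usedByT1;
  const std::set<std::string> usedByT2 = {"b"};
  std::cout << "T1/F0 gets " << truncateForTensor(compatible, usedByT1).size()
            << " axes\n";  // 2: ["a", "b"]
  std::cout << "T2/F0 gets " << truncateForTensor(compatible, usedByT2).size()
            << " axes\n";  // 1: ["a"]
}
```

Each destination tensor is truncated independently, which is exactly what lets T1 receive more axes than T2.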

shardy/dialect/sdy/transforms/propagation/BUILD

+2 −1

```diff
@@ -250,9 +250,10 @@ cc_test(
     srcs = ["aggressive_factor_propagation_test.cc"],
     deps = [
         ":aggressive_factor_propagation",
-        ":basic_factor_propagation",
+        ":factor_propagation",
         ":sharding_projection",
         ":testing_utils",
+        ":utils",
         "//shardy/dialect/sdy/ir:dialect",
         "@com_google_googletest//:gtest_main",
         "@llvm-project//llvm:Support",
```

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc

+52 −106

```diff
@@ -23,146 +23,92 @@ limitations under the License.
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
 
 namespace mlir {
 namespace sdy {
 
-AxesPerFactor
-AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors(
-    const ShardingProjection& projection, PropagationDirection direction,
+namespace {
+
+bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex,
+                          int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
+  if (tensorIndex < projection.getNumOperands()) {
+    return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
+  }
+  return projection.updateResultSharding(
+      tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
+}
+
+}  // namespace
+
+UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
+    ShardingProjection& projection, PropagationDirection direction,
     ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
     bool conservativePropagation) const {
+  UpdateTensorShardings result{
+      .updateOperands = BitVector(projection.getNumOperands()),
+      .updateResults = BitVector(projection.getNumResults())};
   if (direction == PropagationDirection::NONE) {
-    return AxesPerFactor(factorSizes.size());
+    return result;
   }
 
-  // Finds the compatible major axes ignoring conflicts.
-  AxesPerFactor result;
-  result.reserve(factorSizes.size());
+  // Step 1. Find the compatible major axes ignoring conflicts.
+  SmallVector<SmallVector<AxisRefAttr>> axesPerFactor;
+  axesPerFactor.reserve(factorSizes.size());
   for (int64_t i = 0; i < factorSizes.size(); ++i) {
-    result.push_back(getCompatibleMajorAxes(projection, i, direction, op));
+    axesPerFactor.push_back(
+        getCompatibleMajorAxes(projection, i, direction, op));
   }
 
-  // Removes the conflicts within every single factor. This strategy and
-  // `BasicFactorPropagation` handles conflicts within a factor in the same way.
+  // Step 2. Propagate the axes obtained in Step 1, considering conflicts
+  // within the factor, ignoring the conflicts between factors.
+  SmallVector<FactorIndexToSharding> newShardings;
+  newShardings.reserve(projection.getNumOperandsAndResults());
   for (const TensorFactorShardings& tensorFactorShardings :
        llvm::concat<const TensorFactorShardings>(projection.getOperands(),
                                                  projection.getResults())) {
+    newShardings.push_back(tensorFactorShardings.factorIndexToSharding);
     for (const auto& [factorIndex, factorSharding] :
          tensorFactorShardings.factorIndexToSharding) {
-      truncateAxesByRemovingConflicts(
-          result[factorIndex],
+      if (!shouldUpdate(factorSharding.axisRefs, axesPerFactor[factorIndex])) {
+        continue;
+      }
+      SmallVector<AxisRefAttr> newAxes = truncateAxesByRemovingConflicts(
+          axesPerFactor[factorIndex],
           [&, factorIndex = factorIndex, &factorSharding = factorSharding](
               AxisRefAttr axisRef, int64_t shardedSize) {
             return compatiblePrefixNoConflictsWithinFactor(
                 axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
                 shardedSize, factorSizes[factorIndex]);
           },
          mesh, conservativePropagation);
+      newShardings.back()[factorIndex].axisRefs = newAxes;
     }
   }
 
-  // Removes the conflicts across factors, where this strategy and
-  // `BasicFactorPropagation` diverge.
-  //
-  // With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot
-  // overlap with the existing sharding axes or the overflow axes related to all
-  // other factors. This criterion is considered for all tensors, no matter if
-  // Fi is mapped to the tensor or not. The table below shows the criterion:
-  //
-  //                        existing sharding axes & overflow axes  new sharding axes
-  //  factor in tensor      remove overlap                          -
-  //  factor not in tensor  remove overlap                          -
-  //
-  // On the contrary, `AggressiveFactorPropagation` has the following criterion:
-  //
-  //                        existing sharding axes & overflow axes  new sharding axes
-  //  factor in tensor      remove overlap                          remove overlap
-  //  factor not in tensor  -                                       -
-  //
-  // There are two differences:
-  //
-  // 1. `BasicFactorPropagation` removes the overlap between the compatible axes
-  // of a factor Fi with the existing sharding axes and overflow axes in a
-  // tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not
-  // remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too
-  // strict, since we cannot propagate sharding axes to Tj along Fi.
-  //
-  // `AggressiveFactorPropagation` cannot handle the following case if we only
-  // have difference #1. `-` means that the factor is not mapped to the tensor.
-  // After removing conflicts within factors, we will propagate "x" to T2 along
-  // F0 and F1 at the same time, which induces a conflict. To resolve this
-  // conflict, we have difference #2.
-  //
-  //       F0    F1
-  //  T0   "x"   -
-  //  T1   -     "x"
-  //  T2   ?     ?
-  //
-  // 2. `AggressiveFactorPropagation` removes the overlap between compatible
-  // axes of a factor Fi with the potential new sharding axes of other factors
-  // in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi
-  // without conflicts with other factors. In the example, we will not propagate
-  // "x" along F0 or F1 since their potential new sharding axes overlap.
-  //
-  // The potential new sharding axes are saved in `resultSnapshot`. It is a hard
-  // copy since we need to handle the following case.
-  //
-  //       F0    F1    F2
-  //  T0   "x"   -     -
-  //  T1   -     "x"   -
-  //  T2   -     -     "x"
-  //  T3   ?     ?     ?
-  //
-  // The `result` and `resultSnapshot` is [["x"], ["x"], ["x"]] before removing
-  // conflicts across factors. After removing conflicts between F0/F1 and other
-  // factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2
-  // and other factors, if we use `result` as the potential new sharding axes,
-  // we will not remove "x" for F2 because it is no longer present in 'result'
-  // for F0 and F1. We have to use `resultSnapshot` to save the potential new
-  // sharding axes and remove "x" for F2.
-  const AxesPerFactor resultSnapshot = result;
-  for (const TensorFactorShardings& tensorFactorSharding :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
-                                                 projection.getResults())) {
-    for (const auto& [factorIndex, factorSharding] :
-         tensorFactorSharding.factorIndexToSharding) {
-      truncateAxesByRemovingConflicts(
-          result[factorIndex],
+  // Step 3. Remove conflicts (overlapping sharding axes) between factors in
+  // `newShardings`. Update (1) shardings in `projection` and (2) `result`
+  // based on new sharding axes.
+  for (const auto& [tensorIndex, newSharding] : llvm::enumerate(newShardings)) {
+    bool tensorUpdated = false;
+    for (const auto& [factorIndex, factorSharding] : newSharding) {
+      SmallVector<AxisRefAttr> newAxes = truncateAxesByRemovingConflicts(
+          factorSharding.axisRefs,
           [&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
             return compatiblePrefixNoConflictsAcrossFactors(
-                axisRef, tensorFactorSharding.factorIndexToSharding,
-                factorIndex, resultSnapshot);
+                axisRef, newSharding, factorIndex);
           },
           mesh, conservativePropagation);
+      tensorUpdated |=
+          updateTensorSharding(projection, tensorIndex, factorIndex, newAxes);
+    }
+    if (tensorIndex < projection.getNumOperands()) {
+      result.updateOperands[tensorIndex] = tensorUpdated;
+    } else {
+      result.updateResults[tensorIndex - projection.getNumOperands()] =
+          tensorUpdated;
     }
-  }
-
-  return result;
-}
-
-UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
-    ShardingProjection& projection, PropagationDirection direction,
-    ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
-    bool conservativePropagation) const {
-  UpdateTensorShardings result{
-      .updateOperands = BitVector(projection.getNumOperands()),
-      .updateResults = BitVector(projection.getNumResults())};
-
-  // We get the compatible major sharding axes for all factors.
-  AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors(
-      projection, direction, factorSizes, mesh, op, conservativePropagation);
-
-  for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) {
-    // Update all shardings along this factor if possible.
-    auto [updateOperandForFactor, updateResultForFactor] =
-        projection.updateSharding(factorIndex, axesToPropagate);
-
-    result.updateOperands |= updateOperandForFactor;
-    result.updateResults |= updateResultForFactor;
   }
 
   return result;
```
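The per-tensor conflict removal in Step 3 can be sketched standalone. The helper below is hypothetical (simplified string axes instead of `AxisRefAttr`, a plain map instead of `FactorIndexToSharding`); it only mirrors the truncation rule: within one tensor, a factor keeps its proposed axes only up to the first axis that another factor of the same tensor also claims.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

using AxesPerFactorMap = std::map<int, std::vector<std::string>>;

// Keep the longest prefix of factor `factorIndex`'s axes that does not
// overlap with any axis claimed by another factor of the same tensor.
std::vector<std::string> removeCrossFactorConflicts(
    const AxesPerFactorMap& tensorSharding, int factorIndex) {
  std::vector<std::string> kept;
  for (const std::string& axis : tensorSharding.at(factorIndex)) {
    bool conflict = false;
    for (const auto& [otherFactor, otherAxes] : tensorSharding) {
      if (otherFactor == factorIndex) continue;
      for (const std::string& other : otherAxes) {
        if (other == axis) conflict = true;
      }
    }
    if (conflict) break;  // truncate at the first overlapping axis
    kept.push_back(axis);
  }
  return kept;
}

int main() {
  // Tensor T2 from the commit message: both F0 and F1 want "x".
  AxesPerFactorMap t2 = {{0, {"x"}}, {1, {"x"}}};
  std::cout << "F0 keeps " << removeCrossFactorConflicts(t2, 0).size()
            << " axes, F1 keeps " << removeCrossFactorConflicts(t2, 1).size()
            << " axes\n";  // prints: F0 keeps 0 axes, F1 keeps 0 axes
}
```

Because both factors of `t2` claim "x", both are truncated to nothing, which matches the real-conflict behavior (Case 2) described in the header below.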

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h

+32 −22

```diff
@@ -22,40 +22,50 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
+#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
 
 namespace mlir {
 namespace sdy {
 
-// An aggressive strategy of propagating sharding axes along factors.
+// An aggressive strategy of propagating sharding axes along factors. There are
+// two main differences from `BasicFactorPropagation`.
 //
-// This strategy is the same as `BasicFactorPropagation` on the conflicts within
-// a factor. They are different on the conflicts across factors.
+// `BasicFactorPropagation` propagates the same sharding axes to all the tensors
+// along a factor. This strategy can propagate different sharding axes to
+// different tensors. For example, Tensors T0, T1, T2 contain Factor F0. T0/F0
+// is already sharded along ["a", "b"], and "b" is already used by T2 ("b" can
+// be explicitly replicated, or it is used to shard another factor).
+// `BasicFactorPropagation` can only propagate ["a"] to both T1/F0 and T2/F0,
+// while this strategy can propagate ["a", "b"] to T1/F0 and ["a"] to T2/F0.
 //
-// `BasicFactorPropagation` considers the conflicts across factors with a strict
-// criterion. The result cannot overlap with the sharded axes or overflow axes
-// related to all other factors. This aggressive strategy ignores "fake
-// conflicts", which are propagation choices that can co-exist. This aggressive
-// strategy ensures that the resultant axes can be propagated to all tensors
-// containing the factor. Several examples of fake conflicts:
+// `BasicFactorPropagation` is conservative in terms of conflicts across
+// factors. If an axis (or sub-axis) appears in two factor shardings, it cannot
+// be propagated. This strategy is more aggressive, allowing the same axis (or
+// sub-axis) to be propagated along different factors if the resulting
+// shardings are legal (a tensor can be sharded by an axis at most once).
 //
-// 1. An axis is in factors Fi and Fj. If it is infeasible to propagate that
-// axis along factor Fi, we may propagate that axis along factor Fj if all the
-// destination tensors have not used that axis.
+// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a
+// non-contracting dimension of A. F1 corresponds to a non-contracting dimension
+// of B. F2 corresponds to a contracting dimension. "-" means that the tensor
+// does not contain the factor.
 //
-// 2. Two factors Fi and Fj do not co-exist in any tensor, so they never
-// interfere with each other. If Fi and Fj are sharded along the same axis, we
-// can propagate that axis along both factors.
+//      F0    F1    F2
+// A    "a"   -
+// B    -
+// C          "a"   -
+// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while
+// this strategy propagates "a" to B/F1.
 //
-// Although fake conflicts can co-exist without inference, we may still need to
-// all-gather some tensors.
+//      F0    F1    F2
+// A    "a"   -
+// B    -     "a"
+// C                -
+// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy
+// propagate nothing, since propagating "a" to both C/F0 and C/F1 is illegal
+// ("a" cannot be used twice in C).
 class AggressiveFactorPropagation : public BasicFactorPropagation {
  public:
-  AxesPerFactor getCompatibleMajorShardingAxesForAllFactors(
-      const ShardingProjection& projection, PropagationDirection direction,
-      ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
-      bool conservativePropagation) const override;
-
   UpdateTensorShardings propagateFactorShardings(
       ShardingProjection& projection, PropagationDirection direction,
       ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
```