
Commit 7a3f6c1

ZixuanJiang authored and copybara-github committed
Improve the aggressive factor propagation strategy in Shardy. There are two main differences from `BasicFactorPropagation`.

### Difference 1

`BasicFactorPropagation` propagates the same sharding axes to all the tensors along a factor. This strategy can propagate different sharding axes to different tensors. For example, tensors T0, T1, and T2 contain factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is already used by T2 ("b" may be explicitly replicated, or it may shard another factor). `BasicFactorPropagation` propagates ["a"] to both T1/F0 and T2/F0, while this strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0.

### Difference 2

`BasicFactorPropagation` is conservative about conflicts across factors: an axis that overlaps between factors cannot be propagated. This strategy is more aggressive, allowing the overlapping axis to be propagated along different factors as long as the resulting shardings contain no overlapping axes.

PiperOrigin-RevId: 655675663
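To make Difference 1 concrete, here is a minimal self-contained sketch of the per-tensor truncation idea. The names (`truncateForTensor`, `usedByTensor`) are illustrative stand-ins rather than the sdy API: each tensor keeps the longest prefix of a factor's compatible axes that avoids the axes it already uses, so different tensors can receive different prefixes.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

using Axis = std::string;

// Hypothetical stand-in for the truncation step: keep the longest prefix of
// `candidates` whose axes are not already used by the tensor.
std::vector<Axis> truncateForTensor(const std::vector<Axis>& candidates,
                                    const std::set<Axis>& usedByTensor) {
  std::vector<Axis> result;
  for (const Axis& axis : candidates) {
    if (usedByTensor.count(axis)) break;  // conflict: stop at this axis
    result.push_back(axis);
  }
  return result;
}

int main() {
  // T0/F0 is sharded along ["a", "b"]; T2 already uses "b" elsewhere.
  const std::vector<Axis> compatible = {"a", "b"};
  const std::vector<Axis> forT1 = truncateForTensor(compatible, {});
  const std::vector<Axis> forT2 = truncateForTensor(compatible, {"b"});
  // forT1 == {"a", "b"} and forT2 == {"a"}: different tensors receive
  // different prefixes, whereas BasicFactorPropagation would give both
  // tensors the common prefix {"a"}.
  std::cout << forT1.size() << " " << forT2.size() << "\n";  // prints "2 1"
}
```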
1 parent: 838d1aa

14 files changed (+441, -386 lines)

shardy/dialect/sdy/transforms/propagation/BUILD

+2 -1

```diff
@@ -250,9 +250,10 @@ cc_test(
     srcs = ["aggressive_factor_propagation_test.cc"],
     deps = [
         ":aggressive_factor_propagation",
-        ":basic_factor_propagation",
+        ":factor_propagation",
        ":sharding_projection",
        ":testing_utils",
+        ":utils",
        "//shardy/dialect/sdy/ir:dialect",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
```

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc

+62 -109

```diff
@@ -23,146 +23,99 @@ limitations under the License.
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
 
 namespace mlir {
 namespace sdy {
 
-AxesPerFactor
-AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors(
-    const ShardingProjection& projection, PropagationDirection direction,
+namespace {
+
+bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex,
+                          int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
+  if (tensorIndex < projection.getNumOperands()) {
+    return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
+  }
+  return projection.updateResultSharding(
+      tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
+}
+
+}  // namespace
+
+UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
+    ShardingProjection& projection, PropagationDirection direction,
     ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
     bool conservativePropagation) const {
+  UpdateTensorShardings result{
+      .updateOperands = BitVector(projection.getNumOperands()),
+      .updateResults = BitVector(projection.getNumResults())};
   if (direction == PropagationDirection::NONE) {
-    return AxesPerFactor(factorSizes.size());
+    return result;
   }
 
-  // Finds the compatible major axes ignoring conflicts.
-  AxesPerFactor result;
-  result.reserve(factorSizes.size());
+  // Step 1: find the compatible major axes, ignoring conflicts.
+  SmallVector<SmallVector<AxisRefAttr>> axesPerFactor;
+  axesPerFactor.reserve(factorSizes.size());
+  bool allElementsAreEmpty = true;
   for (int64_t i = 0; i < factorSizes.size(); ++i) {
-    result.push_back(getCompatibleMajorAxes(projection, i, direction, op));
+    SmallVector<AxisRefAttr>& axes = axesPerFactor.emplace_back(
+        getCompatibleMajorAxes(projection, i, direction, op));
+    if (!axes.empty()) {
+      allElementsAreEmpty = false;
+    }
+  }
+  if (allElementsAreEmpty) {
+    return result;
   }
 
-  // Removes the conflicts within every single factor. This strategy and
-  // `BasicFactorPropagation` handles conflicts within a factor in the same way.
-  for (const TensorFactorShardings& tensorFactorShardings :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
-                                                 projection.getResults())) {
-    for (const auto& [factorIndex, factorSharding] :
-         tensorFactorShardings.factorIndexToSharding) {
+  // The propagation on each tensor is independent. This strategy can propagate
+  // different shardings to different tensors along the same factor. Examples
+  // are provided in the docstring of this class.
+  for (const auto& [tensorIndex, tensorFactorShardings] :
+       llvm::enumerate(llvm::concat<const TensorFactorShardings>(
+           projection.getOperands(), projection.getResults()))) {
+    // Propagate the axes from Step 1 and resolve conflicts within a factor.
+    FactorIndexToSharding newSharding =
+        tensorFactorShardings.factorIndexToSharding;
+    BitVector factorUpdated(factorSizes.size());
+    for (auto& [factorIndex, factorSharding] : newSharding) {
+      SmallVector<AxisRefAttr> newAxes = axesPerFactor[factorIndex];
       truncateAxesByRemovingConflicts(
-          result[factorIndex],
+          newAxes,
          [&, factorIndex = factorIndex, &factorSharding = factorSharding](
              AxisRefAttr axisRef, int64_t shardedSize) {
            return compatiblePrefixNoConflictsWithinFactor(
                axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
                shardedSize, factorSizes[factorIndex]);
          },
          mesh, conservativePropagation);
+      if (shouldUpdate(factorSharding.axisRefs, newAxes)) {
+        factorSharding.axisRefs = newAxes;
+        factorUpdated.set(factorIndex);
+      }
     }
-  }
 
-  // Removes the conflicts across factors, where this strategy and
-  // `BasicFactorPropagation` diverge.
-  //
-  // With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot
-  // overlap with the existing sharding axes or the overflow axes related to all
-  // other factors. This criterion is considered for all tensors, no matter if
-  // Fi is mapped to the tensor or not. The table below shows the criterion:
-  //
-  //                         existing sharding axes & overflow axes   new sharding axes
-  //   factor in tensor      remove overlap                           -
-  //   factor not in tensor  remove overlap                           -
-  //
-  // On the contrary, `AggressiveFactorPropagation` has the following criterion:
-  //
-  //                         existing sharding axes & overflow axes   new sharding axes
-  //   factor in tensor      remove overlap                           remove overlap
-  //   factor not in tensor  -                                        -
-  //
-  // There are two differences:
-  //
-  // 1. `BasicFactorPropagation` removes the overlap between the compatible axes
-  // of a factor Fi with the existing sharding axes and overflow axes in a
-  // tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not
-  // remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too
-  // strict, since we cannot propagate sharding axes to Tj along Fi.
-  //
-  // `AggressiveFactorPropagation` cannot handle the following case if we only
-  // have difference #1. `-` means that the factor is not mapped to the tensor.
-  // After removing conflicts within factors, we will propagate "x" to T2 along
-  // F0 and F1 at the same time, which induces a conflict. To resolve this
-  // conflict, we have difference #2.
-  //
-  //       F0   F1
-  //   T0  "x"  -
-  //   T1  -    "x"
-  //   T2  ?    ?
-  //
-  // 2. `AggressiveFactorPropagation` removes the overlap between compatible
-  // axes of a factor Fi with the potential new sharding axes of other factors
-  // in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi
-  // without conflicts with other factors. In the example, we will not propagate
-  // "x" along F0 or F1 since their potential new sharding axes overlap.
-  //
-  // The potential new sharding axes are saved in `resultSnapshot`. It is a hard
-  // copy since we need to handle the following case.
-  //
-  //       F0   F1   F2
-  //   T0  "x"  -    -
-  //   T1  -    "x"  -
-  //   T2  -    -    "x"
-  //   T3  ?    ?    ?
-  //
-  // The `result` and `resultSnapshot` is [["x"], ["x"], ["x"]] before removing
-  // conflicts across factors. After removing conflicts between F0/F1 and other
-  // factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2
-  // and other factors, if we use `result` as the potential new sharding axes,
-  // we will not remove "x" for F2 because it is no longer present in `result`
-  // for F0 and F1. We have to use `resultSnapshot` to save the potential new
-  // sharding axes and remove "x" for F2.
-  const AxesPerFactor resultSnapshot = result;
-  for (const TensorFactorShardings& tensorFactorSharding :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
-                                                 projection.getResults())) {
-    for (const auto& [factorIndex, factorSharding] :
-         tensorFactorSharding.factorIndexToSharding) {
+    // Resolve conflicts (overlapping sharding axes) between factors.
+    bool tensorUpdated = false;
+    for (const int64_t factorIndex : factorUpdated.set_bits()) {
+      SmallVector<AxisRefAttr> newAxes = newSharding[factorIndex].axisRefs;
       truncateAxesByRemovingConflicts(
-          result[factorIndex],
+          newAxes,
          [&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
            return compatiblePrefixNoConflictsAcrossFactors(
-                axisRef, tensorFactorSharding.factorIndexToSharding,
-                factorIndex, resultSnapshot);
+                axisRef, newSharding, factorIndex);
          },
          mesh, conservativePropagation);
+      tensorUpdated |=
+          updateTensorSharding(projection, tensorIndex, factorIndex, newAxes);
     }
-  }
 
-  return result;
-}
-
-UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
-    ShardingProjection& projection, PropagationDirection direction,
-    ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
-    bool conservativePropagation) const {
-  UpdateTensorShardings result{
-      .updateOperands = BitVector(projection.getNumOperands()),
-      .updateResults = BitVector(projection.getNumResults())};
-
-  // We get the compatible major sharding axes for all factors.
-  AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors(
-      projection, direction, factorSizes, mesh, op, conservativePropagation);
-
-  for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) {
-    // Update all shardings along this factor if possible.
-    auto [updateOperandForFactor, updateResultForFactor] =
-        projection.updateSharding(factorIndex, axesToPropagate);
-
-    result.updateOperands |= updateOperandForFactor;
-    result.updateResults |= updateResultForFactor;
+    if (tensorIndex < projection.getNumOperands()) {
+      result.updateOperands[tensorIndex] = tensorUpdated;
+    } else {
+      result.updateResults[tensorIndex - projection.getNumOperands()] =
+          tensorUpdated;
+    }
   }
 
   return result;
```
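The restructured `propagateFactorShardings` above handles each tensor independently in two passes: pass 1 writes each factor's truncated candidate axes into a copy of the tensor's sharding (`newSharding`), and pass 2 truncates every updated factor against that pass-1 snapshot. The sketch below (simplified types, not the real sdy classes) shows why checking against the snapshot matters: two factors that both tentatively claim the same axis are both truncated, which is how Case 2 ("real conflict") in the class docstring resolves to propagating nothing.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Axis = std::string;
// Simplified stand-in for FactorIndexToSharding: factor index -> axes.
using TensorSharding = std::map<int64_t, std::vector<Axis>>;

// Does `axis` appear in any factor of `sharding` other than `factorIndex`?
bool usedByOtherFactor(const TensorSharding& sharding, int64_t factorIndex,
                       const Axis& axis) {
  for (const auto& [factor, axes] : sharding) {
    if (factor == factorIndex) continue;
    for (const Axis& other : axes) {
      if (other == axis) return true;
    }
  }
  return false;
}

// Pass 2 of the per-tensor loop: cut each factor's tentative axes at the
// first cross-factor overlap. Every factor is checked against the same pass-1
// snapshot, so two factors that both claim an axis are BOTH truncated.
TensorSharding resolveAcrossFactors(const TensorSharding& tentative) {
  TensorSharding resolved;
  for (const auto& [factorIndex, axes] : tentative) {
    std::vector<Axis>& kept = resolved[factorIndex];
    for (const Axis& axis : axes) {
      if (usedByOtherFactor(tentative, factorIndex, axis)) break;
      kept.push_back(axis);
    }
  }
  return resolved;
}

int main() {
  // Tensor C in Case 2 of the class docstring: after pass 1, F0 and F1 both
  // tentatively carry "a". Checking against the snapshot truncates both, so
  // nothing is propagated to C.
  const TensorSharding tentative = {{0, {"a"}}, {1, {"a"}}};
  const TensorSharding resolved = resolveAcrossFactors(tentative);
  std::cout << resolved.at(0).size() << " " << resolved.at(1).size() << "\n";
  // Prints "0 0".
}
```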

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h

+35 -22

```diff
@@ -22,40 +22,53 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
+#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
 
 namespace mlir {
 namespace sdy {
 
-// An aggressive strategy of propagating sharding axes along factors.
+// An aggressive strategy of propagating sharding axes along factors. There are
+// two main differences from `BasicFactorPropagation`.
 //
-// This strategy is the same as `BasicFactorPropagation` on the conflicts within
-// a factor. They are different on the conflicts across factors.
+// `BasicFactorPropagation` propagates the same sharding axes to all tensors
+// along a factor. This strategy can propagate different sharding axes to
+// different tensors along the same factor. For example, tensors T0, T1, T2
+// contain factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is
+// already used by T2 ("b" can be explicitly replicated, or it is used to shard
+// another factor). `BasicFactorPropagation` propagates ["a"] to both T1/F0 and
+// T2/F0, while this strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0,
+// respectively. If T2/F0 is closed, `BasicFactorPropagation` propagates
+// nothing, while this strategy propagates nothing to T2/F0 and still propagates
+// ["a", "b"] to T1/F0.
 //
-// `BasicFactorPropagation` considers the conflicts across factors with a strict
-// criterion. The result cannot overlap with the sharded axes or overflow axes
-// related to all other factors. This aggressive strategy ignores "fake
-// conflicts", which are propagation choices that can co-exist. This aggressive
-// strategy ensures that the resultant axes can be propagated to all tensors
-// containing the factor. Several examples of fake conflicts:
+// `BasicFactorPropagation` is conservative in terms of conflicts across
+// factors. An axis that overlaps between factors cannot be propagated. This
+// strategy is more aggressive, allowing the overlapping axis to be propagated
+// along different factors as long as there is no overlapping axis in the
+// result shardings.
 //
-// 1. An axis is in factors Fi and Fj. If it is infeasible to propagate that
-// axis along factor Fi, we may propagate that axis along factor Fj if all the
-// destination tensors have not used that axis.
+// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a
+// non-contracting dimension of A. F1 corresponds to a non-contracting dimension
+// of B. F2 corresponds to a contracting dimension. "-" means that the tensor
+// does not contain the factor.
 //
-// 2. Two factors Fi and Fj do not co-exist in any tensor, so they never
-// interfere with each other. If Fi and Fj are sharded along the same axis, we
-// can propagate that axis along both factors.
+//      F0   F1   F2
+// A   "a"   -
+// B    -
+// C        "a"   -
+// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while
+// this strategy propagates "a" to B/F1.
 //
-// Although fake conflicts can co-exist without interference, we may still need to
-// all-gather some tensors.
+//      F0   F1   F2
+// A   "a"   -
+// B    -   "a"
+// C              -
+// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy
+// propagate nothing, since propagating "a" to C along both F0 and F1 would
+// use "a" twice in C, which is illegal.
 class AggressiveFactorPropagation : public BasicFactorPropagation {
  public:
-  AxesPerFactor getCompatibleMajorShardingAxesForAllFactors(
-      const ShardingProjection& projection, PropagationDirection direction,
-      ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
-      bool conservativePropagation) const override;
-
   UpdateTensorShardings propagateFactorShardings(
       ShardingProjection& projection, PropagationDirection direction,
       ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
```
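As a worked check of Case 1 in the docstring above, the toy below (illustrative types only, not the sdy API) runs a simplified per-tensor propagation over the docstring's table: B/F1 receives "a", while C stays unchanged because F0's tentative "a" would collide with the "a" already on C/F1.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Axis = std::string;
using TensorSharding = std::map<int64_t, std::vector<Axis>>;  // factor -> axes

// Does `axis` appear in a factor of `sharding` other than `factorIndex`?
bool overlapsOtherFactor(const TensorSharding& sharding, int64_t factorIndex,
                         const Axis& axis) {
  for (const auto& [factor, axes] : sharding) {
    if (factor == factorIndex) continue;
    for (const Axis& other : axes) {
      if (other == axis) return true;
    }
  }
  return false;
}

// Simplified per-tensor propagation. Pass 1 tentatively gives each unsharded
// factor its candidate axes; pass 2 re-checks only the factors that changed,
// against the pass-1 snapshot, and drops axes that overlap another factor.
TensorSharding propagateToTensor(const TensorSharding& tensor,
                                 const TensorSharding& candidates) {
  TensorSharding tentative = tensor;
  std::vector<int64_t> updated;
  for (auto& [factor, axes] : tentative) {
    if (axes.empty() && !candidates.at(factor).empty()) {
      axes = candidates.at(factor);
      updated.push_back(factor);
    }
  }
  TensorSharding result = tentative;
  for (int64_t factor : updated) {
    std::vector<Axis> kept;
    for (const Axis& axis : tentative.at(factor)) {
      if (overlapsOtherFactor(tentative, factor, axis)) break;
      kept.push_back(axis);
    }
    result[factor] = kept;
  }
  return result;
}

int main() {
  // Case 1: the candidate axes are "a" for F0 (from A/F0) and "a" for F1
  // (from C/F1); F2 has no candidates.
  const TensorSharding candidates = {{0, {"a"}}, {1, {"a"}}, {2, {}}};
  // B contains F1 and F2, both unsharded: B/F1 receives "a".
  TensorSharding b = propagateToTensor({{1, {}}, {2, {}}}, candidates);
  // C contains F0 (unsharded) and F1 (already "a"): F0's tentative "a"
  // overlaps C/F1, so nothing is propagated to C.
  TensorSharding c = propagateToTensor({{0, {}}, {1, {"a"}}}, candidates);
  std::cout << b[1].size() << " " << c[0].size() << "\n";  // prints "1 0"
}
```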
