@@ -23,146 +23,92 @@ limitations under the License.
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"

 namespace mlir {
 namespace sdy {

-AxesPerFactor
-AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors(
-    const ShardingProjection& projection, PropagationDirection direction,
+namespace {
+
+bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex,
+                          int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
+  if (tensorIndex < projection.getNumOperands()) {
+    return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
+  }
+  return projection.updateResultSharding(
+      tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
+}
+
+}  // namespace
+
+UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
+    ShardingProjection& projection, PropagationDirection direction,
     ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
     bool conservativePropagation) const {
+  UpdateTensorShardings result{
+      .updateOperands = BitVector(projection.getNumOperands()),
+      .updateResults = BitVector(projection.getNumResults())};
   if (direction == PropagationDirection::NONE) {
-    return AxesPerFactor(factorSizes.size());
+    return result;
   }

-  // Finds the compatible major axes ignoring conflicts.
-  AxesPerFactor result;
-  result.reserve(factorSizes.size());
+  // Step 1. Find the compatible major axes ignoring conflicts.
+  SmallVector<SmallVector<AxisRefAttr>> axesPerFactor;
+  axesPerFactor.reserve(factorSizes.size());
   for (int64_t i = 0; i < factorSizes.size(); ++i) {
-    result.push_back(getCompatibleMajorAxes(projection, i, direction, op));
+    axesPerFactor.push_back(
+        getCompatibleMajorAxes(projection, i, direction, op));
   }

-  // Removes the conflicts within every single factor. This strategy and
-  // `BasicFactorPropagation` handle conflicts within a factor in the same way.
+  // Step 2. Propagate the axes obtained in Step 1, considering conflicts
+  // within each factor but ignoring conflicts between factors.
+  SmallVector<FactorIndexToSharding> newShardings;
+  newShardings.reserve(projection.getNumOperandsAndResults());
   for (const TensorFactorShardings& tensorFactorShardings :
        llvm::concat<const TensorFactorShardings>(projection.getOperands(),
                                                  projection.getResults())) {
+    newShardings.push_back(tensorFactorShardings.factorIndexToSharding);
     for (const auto& [factorIndex, factorSharding] :
          tensorFactorShardings.factorIndexToSharding) {
-      truncateAxesByRemovingConflicts(
-          result[factorIndex],
+      if (!shouldUpdate(factorSharding.axisRefs, axesPerFactor[factorIndex])) {
+        continue;
+      }
+      SmallVector<AxisRefAttr> newAxes = truncateAxesByRemovingConflicts(
+          axesPerFactor[factorIndex],
           [&, factorIndex = factorIndex, &factorSharding = factorSharding](
              AxisRefAttr axisRef, int64_t shardedSize) {
            return compatiblePrefixNoConflictsWithinFactor(
                axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
                shardedSize, factorSizes[factorIndex]);
          },
          mesh, conservativePropagation);
+      newShardings.back()[factorIndex].axisRefs = newAxes;
     }
   }

-  // Removes the conflicts across factors, where this strategy and
-  // `BasicFactorPropagation` diverge.
-  //
-  // With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot
-  // overlap with the existing sharding axes or the overflow axes related to all
-  // other factors. This criterion is considered for all tensors, no matter if
-  // Fi is mapped to the tensor or not. The table below shows the criterion:
-  //
-  //                          existing sharding axes & overflow axes   new sharding axes
-  //   factor in tensor                  remove overlap                       -
-  //   factor not in tensor              remove overlap                       -
-  //
-  // On the contrary, `AggressiveFactorPropagation` has the following criterion:
-  //
-  //                          existing sharding axes & overflow axes   new sharding axes
-  //   factor in tensor                  remove overlap                remove overlap
-  //   factor not in tensor                    -                              -
-  //
-  // There are two differences:
-  //
-  // 1. `BasicFactorPropagation` removes the overlap between the compatible axes
-  // of a factor Fi with the existing sharding axes and overflow axes in a
-  // tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not
-  // remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too
-  // strict, since we cannot propagate sharding axes to Tj along Fi.
-  //
-  // `AggressiveFactorPropagation` cannot handle the following case if we only
-  // have difference #1. `-` means that the factor is not mapped to the tensor.
-  // After removing conflicts within factors, we will propagate "x" to T2 along
-  // F0 and F1 at the same time, which induces a conflict. To resolve this
-  // conflict, we have difference #2.
-  //
-  //         F0    F1
-  //   T0    "x"   -
-  //   T1    -     "x"
-  //   T2    ?     ?
-  //
-  // 2. `AggressiveFactorPropagation` removes the overlap between the compatible
-  // axes of a factor Fi with the potential new sharding axes of other factors
-  // in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi
-  // without conflicts with other factors. In the example, we will not propagate
-  // "x" along F0 or F1 since their potential new sharding axes overlap.
-  //
-  // The potential new sharding axes are saved in `resultSnapshot`. It is an
-  // explicit copy since we need to handle the following case.
-  //
-  //         F0    F1    F2
-  //   T0    "x"   -     -
-  //   T1    -     "x"   -
-  //   T2    -     -     "x"
-  //   T3    ?     ?     ?
-  //
-  // The `result` and `resultSnapshot` are [["x"], ["x"], ["x"]] before removing
-  // conflicts across factors. After removing conflicts between F0/F1 and other
-  // factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2
-  // and other factors, if we use `result` as the potential new sharding axes,
-  // we will not remove "x" for F2 because it is no longer present in `result`
-  // for F0 and F1. We have to use `resultSnapshot` to save the potential new
-  // sharding axes and remove "x" for F2.
-  const AxesPerFactor resultSnapshot = result;
-  for (const TensorFactorShardings& tensorFactorSharding :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
-                                                 projection.getResults())) {
-    for (const auto& [factorIndex, factorSharding] :
-         tensorFactorSharding.factorIndexToSharding) {
-      truncateAxesByRemovingConflicts(
-          result[factorIndex],
+  // Step 3. Remove conflicts (overlapping sharding axes) between factors in
+  // `newShardings`. Update (1) shardings in `projection` and (2) `result` based
+  // on new sharding axes.
+  for (const auto& [tensorIndex, newSharding] : llvm::enumerate(newShardings)) {
+    bool tensorUpdated = false;
+    for (const auto& [factorIndex, factorSharding] : newSharding) {
+      SmallVector<AxisRefAttr> newAxes = truncateAxesByRemovingConflicts(
+          factorSharding.axisRefs,
           [&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
             return compatiblePrefixNoConflictsAcrossFactors(
-                axisRef, tensorFactorSharding.factorIndexToSharding,
-                factorIndex, resultSnapshot);
+                axisRef, newSharding, factorIndex);
           },
           mesh, conservativePropagation);
+      tensorUpdated |=
+          updateTensorSharding(projection, tensorIndex, factorIndex, newAxes);
+    }
+    if (tensorIndex < projection.getNumOperands()) {
+      result.updateOperands[tensorIndex] = tensorUpdated;
+    } else {
+      result.updateResults[tensorIndex - projection.getNumOperands()] =
+          tensorUpdated;
     }
-  }
-
-  return result;
-}
-
-UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
-    ShardingProjection& projection, PropagationDirection direction,
-    ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
-    bool conservativePropagation) const {
-  UpdateTensorShardings result{
-      .updateOperands = BitVector(projection.getNumOperands()),
-      .updateResults = BitVector(projection.getNumResults())};
-
-  // We get the compatible major sharding axes for all factors.
-  AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors(
-      projection, direction, factorSizes, mesh, op, conservativePropagation);
-
-  for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) {
-    // Update all shardings along this factor if possible.
-    auto [updateOperandForFactor, updateResultForFactor] =
-        projection.updateSharding(factorIndex, axesToPropagate);
-
-    result.updateOperands |= updateOperandForFactor;
-    result.updateResults |= updateResultForFactor;
   }

   return result;
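
The new code walks operands and results as one flattened list: a `tensorIndex` below `getNumOperands()` addresses an operand, and anything above addresses a result after subtracting that count, as seen in both `updateTensorSharding` and the Step 3 bookkeeping. The following self-contained sketch illustrates just that dispatch with standard-library stand-ins; `UpdateFlags` and `markUpdated` are hypothetical names, and the real code uses `llvm::BitVector` and `ShardingProjection`.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Stand-in for UpdateTensorShardings: one flag per operand and per result.
    struct UpdateFlags {
      std::vector<bool> updateOperands;
      std::vector<bool> updateResults;
    };

    // Mirrors the dispatch in the new updateTensorSharding helper: route a
    // flat tensor index to the operand or result side. `updated` stands in for
    // the return value of updateOperandSharding/updateResultSharding.
    void markUpdated(UpdateFlags& flags, int64_t numOperands,
                     int64_t tensorIndex, bool updated) {
      if (tensorIndex < numOperands) {
        flags.updateOperands[tensorIndex] = updated;
      } else {
        flags.updateResults[tensorIndex - numOperands] = updated;
      }
    }

    int main() {
      // Two operands and one result, as in a typical binary op.
      UpdateFlags flags{std::vector<bool>(2), std::vector<bool>(1)};
      markUpdated(flags, /*numOperands=*/2, /*tensorIndex=*/2, /*updated=*/true);
      std::cout << flags.updateResults[0] << "\n";  // Prints 1: result updated.
    }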
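The removed comment's F0/F1 example is the heart of Step 3: two factors may each want to introduce the same axis into one tensor, and then neither may keep it. Below is a minimal sketch of that rule under the assumption that `compatiblePrefixNoConflictsAcrossFactors` keeps the longest prefix of a factor's candidate axes that no other factor in the same tensor also claims; the types are simplified (axes as strings), and `FactorToAxes` and `truncateAcrossFactors` are hypothetical stand-ins for `FactorIndexToSharding` and the real truncation helpers.

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for FactorIndexToSharding: per-factor candidate
    // axes for one tensor.
    using FactorToAxes = std::map<int, std::vector<std::string>>;

    // Truncates the candidate axes of `factorIndex` at the first axis that
    // some other factor in the same tensor also wants to use.
    std::vector<std::string> truncateAcrossFactors(const FactorToAxes& tensor,
                                                   int factorIndex) {
      std::vector<std::string> result;
      for (const std::string& axis : tensor.at(factorIndex)) {
        for (const auto& [otherFactor, otherAxes] : tensor) {
          if (otherFactor == factorIndex) continue;
          for (const std::string& other : otherAxes) {
            if (other == axis) return result;  // Conflict: keep prefix so far.
          }
        }
        result.push_back(axis);
      }
      return result;
    }

    int main() {
      // T2 from the removed comment's example: both F0 and F1 would propagate
      // "x" into the same tensor.
      FactorToAxes t2 = {{0, {"x"}}, {1, {"x"}}};
      std::cout << truncateAcrossFactors(t2, 0).size() << "\n";  // 0
      std::cout << truncateAcrossFactors(t2, 1).size() << "\n";  // 0
    }

Running it prints 0 and 0: neither F0 nor F1 gets to propagate "x" into T2, matching the conflict resolution described in the removed comment and now performed per tensor in Step 3.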