@@ -96,6 +96,36 @@ bool isStrictPrefixOfFactorSharding(
96
96
return false ;
97
97
}
98
98
99
+ // Only propagate axes to operands that are also present in at least one result.
100
+ //
101
+ // We want to avoid the following situation which can happen when a
102
+ // `sharding_constraint` is added onto the operand during Shardy import:
103
+ // ```
104
+ // %arg0: [{"a", ?}]
105
+ // %arg1: [{?}]
106
+ // %0 = add %arg0, %arg1 : [{}]
107
+ // ```
108
+ // We don't want to do an all-gather on both %arg0 and %arg1 due to "a"
109
+ // propagating sideways. Instead with the code below, since "a" can't
110
+ // propagate to `%0`, we will only do an all-gather on %arg0.
111
+ //
112
+ // TODO(b/396642774): Long term we should undo this and allow sideways
113
+ // propagation, but have our explicit reshard pass make sure the result is
114
+ // all-gathered instead of both operands.
115
+ void cancelSidewaysPropagationForElementwise (ShardingProjection& projection,
116
+ int64_t factorIndex,
117
+ SmallVector<AxisRefAttr>& newAxes,
118
+ Operation* op) {
119
+ if (!op || !isElementwise (op)) {
120
+ return ;
121
+ }
122
+ for (const TensorFactorShardings& result : projection.getResults ()) {
123
+ if (isStrictPrefixOfFactorSharding (result, factorIndex, newAxes)) {
124
+ newAxes = result.factorIndexToSharding .at (factorIndex).axisRefs ;
125
+ }
126
+ }
127
+ }
128
+
99
129
} // namespace
100
130
101
131
SmallVector<AxisRefAttr>
@@ -114,9 +144,7 @@ AggressiveFactorPropagation::getPropagatedFactorSharding(
114
144
// Resolve conflicts within a factor.
115
145
truncateAxesByRemovingConflicts (
116
146
newAxes,
117
- [&, factorIndex = factorIndex,
118
- &tensorFactorShardings = tensorFactorShardings](
119
- AxisRefAttr axisRef, int64_t prevShardedSize) {
147
+ [&](AxisRefAttr axisRef, int64_t prevShardedSize) {
120
148
return compatiblePrefixNoConflictsWithinFactor (
121
149
axisRef, tensorFactorShardings.replicatedAxes , factorSharding,
122
150
prevShardedSize, factorSizes[factorIndex], mesh);
@@ -133,7 +161,7 @@ AggressiveFactorPropagation::getPropagatedFactorSharding(
133
161
// checking for conflicts w.r.t. the updated state of this tensor.
134
162
truncateAxesByRemovingConflicts (
135
163
newAxes,
136
- [&, factorIndex = factorIndex ](AxisRefAttr axisRef, int64_t ) {
164
+ [&](AxisRefAttr axisRef, int64_t ) {
137
165
return compatiblePrefixNoConflictsAcrossFactors (
138
166
axisRef, factorIndexToSharding, factorIndex);
139
167
},
@@ -182,71 +210,54 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
182
210
factorToSourceTensor[j].index , j);
183
211
});
184
212
185
- for (const auto & [tensorIndex, tensorFactorShardings] :
186
- llvm::enumerate (projection.getResults ())) {
187
- const FactorIndexToSharding& factorIndexToSharding =
188
- tensorFactorShardings.factorIndexToSharding ;
189
-
190
- // Propagate the axes got in Step 1, resolving conflicts between factors by
191
- // following the order of preference in `sortedFactorIndices`.
192
- bool tensorUpdated = false ;
193
- for (int64_t factorIndex : sortedFactorIndices) {
194
- SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding (
195
- factorIndex, tensorFactorShardings, factorIndexToSharding,
196
- axesPerFactor, mesh, conservativePropagation, factorSizes);
197
- if (newAxes.empty ()) {
198
- continue ;
213
+ // Propagate the axes got in Step 1, resolving conflicts between factors by
214
+ // following the order of preference in `sortedFactorIndices`.
215
+ for (int64_t factorIndex : sortedFactorIndices) {
216
+ PropagationDirection direction = directionAlongFactor (factorIndex);
217
+ // 1. Propagate to results.
218
+ //
219
+ // We don't propagate sideways between results in backwards propagation,
220
+ // so the sharding of the results along this factor shouldn't change.
221
+ if (direction != PropagationDirection::BACKWARD) {
222
+ for (const auto & [tensorIndex, tensorFactorShardings] :
223
+ llvm::enumerate (projection.getResults ())) {
224
+ SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding (
225
+ factorIndex, tensorFactorShardings,
226
+ tensorFactorShardings.factorIndexToSharding , axesPerFactor, mesh,
227
+ conservativePropagation, factorSizes);
228
+ if (newAxes.empty ()) {
229
+ continue ;
230
+ }
231
+ if (expandTensorSharding (projection,
232
+ tensorIndex + projection.getNumOperands (),
233
+ factorIndex, newAxes)) {
234
+ result.updateResults .set (tensorIndex);
235
+ }
199
236
}
200
- tensorUpdated |= expandTensorSharding (
201
- projection, tensorIndex + projection.getNumOperands (), factorIndex,
202
- newAxes);
203
237
}
204
- result.updateResults [tensorIndex] = tensorUpdated;
205
- }
206
-
207
- for (const auto & [tensorIndex, tensorFactorShardings] :
208
- llvm::enumerate (projection.getOperands ())) {
209
- const FactorIndexToSharding& factorIndexToSharding =
210
- tensorFactorShardings.factorIndexToSharding ;
211
-
212
- // Propagate the axes got in Step 1, resolving conflicts between factors by
213
- // following the order of preference in `sortedFactorIndices`.
214
- bool tensorUpdated = false ;
215
- for (int64_t factorIndex : sortedFactorIndices) {
216
- SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding (
217
- factorIndex, tensorFactorShardings, factorIndexToSharding,
218
- axesPerFactor, mesh, conservativePropagation, factorSizes);
219
- if (newAxes.empty ()) {
220
- continue ;
221
- }
222
238
223
- // Only propagate sideways through operands the factors that are also
224
- // used in at least one result We want to avoid the following situation
225
- // which can happen when a `sharding_constraint` is added onto the operand
226
- // during Shardy import:
227
- // ```
228
- // %arg0: [{"a", ?}]
229
- // %arg1: [{?}]
230
- // %0 = add %arg0, %arg1 : [{}]
231
- // ```
232
- // We don't want to do an all-gather on both %arg0 and %arg1 due to "a"
233
- // propagating sideways. Instead with the code below, since "a" can't
234
- // propagate to `%0`, we will only do an all-gather on %arg0.
235
- //
236
- // TODO(b/396642774): Long term we should undo this and allow sideways
237
- // propagation, but have our explicit reshard pass make sure the result is
238
- // all-gathered instead of both operands.
239
- if (op && isElementwise (op)) {
240
- for (const TensorFactorShardings& result : projection.getResults ()) {
241
- if (isStrictPrefixOfFactorSharding (result, factorIndex, newAxes)) {
242
- newAxes = result.factorIndexToSharding .at (factorIndex).axisRefs ;
243
- }
239
+ // 2. Propagate to operands.
240
+ //
241
+ // We don't propagate sideways between operands in forward propagation,
242
+ // so the sharding of the operands along this factor shouldn't change.
243
+ if (direction != PropagationDirection::FORWARD) {
244
+ for (const auto & [tensorIndex, tensorFactorShardings] :
245
+ llvm::enumerate (projection.getOperands ())) {
246
+ SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding (
247
+ factorIndex, tensorFactorShardings,
248
+ tensorFactorShardings.factorIndexToSharding , axesPerFactor, mesh,
249
+ conservativePropagation, factorSizes);
250
+ if (newAxes.empty ()) {
251
+ continue ;
252
+ }
253
+ cancelSidewaysPropagationForElementwise (projection, factorIndex,
254
+ newAxes, op);
255
+ if (expandTensorSharding (projection, tensorIndex, factorIndex,
256
+ newAxes)) {
257
+ result.updateOperands .set (tensorIndex);
244
258
}
245
259
}
246
- tensorUpdated |=
247
- expandTensorSharding (projection, tensorIndex, factorIndex, newAxes);
248
260
}
249
- result.updateOperands [tensorIndex] = tensorUpdated;
250
261
}
251
262
return result;
252
263
}
0 commit comments