
Commit 538f262

tomnatan30 authored and copybara-github committed
Don't allow sideways propagation between operands/results when the propagation direction is forward/backward.

The rationale is that a forward/backward propagation isn't meant to update the sharding of other operands/results.

PiperOrigin-RevId: 734104394
1 parent 201c8ea commit 538f262

8 files changed: +295 -79 lines changed
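
The rule this commit enforces can be restated compactly. The following is a minimal sketch, not code from the commit: the enum mirrors sdy's PropagationDirection, while mayUpdateOperands/mayUpdateResults are hypothetical helpers introduced only for illustration.

#include <cstdio>
#include <utility>

enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };

// Forward propagation must not update operands sideways.
bool mayUpdateOperands(PropagationDirection d) {
  return d == PropagationDirection::BACKWARD || d == PropagationDirection::BOTH;
}

// Backward propagation must not update results sideways.
bool mayUpdateResults(PropagationDirection d) {
  return d == PropagationDirection::FORWARD || d == PropagationDirection::BOTH;
}

int main() {
  const std::pair<const char*, PropagationDirection> cases[] = {
      {"NONE", PropagationDirection::NONE},
      {"FORWARD", PropagationDirection::FORWARD},
      {"BACKWARD", PropagationDirection::BACKWARD},
      {"BOTH", PropagationDirection::BOTH}};
  for (const auto& [name, d] : cases) {
    std::printf("%-8s may update: operands=%d results=%d\n", name,
                mayUpdateOperands(d), mayUpdateResults(d));
  }
}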

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc (+74 -63)

@@ -96,6 +96,36 @@ bool isStrictPrefixOfFactorSharding(
   return false;
 }
 
+// Only propagate axes to operands that are also present in at least one result.
+//
+// We want to avoid the following situation which can happen when a
+// `sharding_constraint` is added onto the operand during Shardy import:
+// ```
+// %arg0: [{"a", ?}]
+// %arg1: [{?}]
+// %0 = add %arg0, %arg1 : [{}]
+// ```
+// We don't want to do an all-gather on both %arg0 and %arg1 due to "a"
+// propagating sideways. Instead with the code below, since "a" can't
+// propagate to `%0`, we will only do an all-gather on %arg0.
+//
+// TODO(b/396642774): Long term we should undo this and allow sideways
+// propagation, but have our explicit reshard pass make sure the result is
+// all-gathered instead of both operands.
+void cancelSidewaysPropagationForElementwise(ShardingProjection& projection,
+                                             int64_t factorIndex,
+                                             SmallVector<AxisRefAttr>& newAxes,
+                                             Operation* op) {
+  if (!op || !isElementwise(op)) {
+    return;
+  }
+  for (const TensorFactorShardings& result : projection.getResults()) {
+    if (isStrictPrefixOfFactorSharding(result, factorIndex, newAxes)) {
+      newAxes = result.factorIndexToSharding.at(factorIndex).axisRefs;
+    }
+  }
+}
+
 }  // namespace
 
 SmallVector<AxisRefAttr>
@@ -114,9 +144,7 @@ AggressiveFactorPropagation::getPropagatedFactorSharding(
   // Resolve conflicts within a factor.
   truncateAxesByRemovingConflicts(
       newAxes,
-      [&, factorIndex = factorIndex,
-       &tensorFactorShardings = tensorFactorShardings](
-          AxisRefAttr axisRef, int64_t prevShardedSize) {
+      [&](AxisRefAttr axisRef, int64_t prevShardedSize) {
        return compatiblePrefixNoConflictsWithinFactor(
            axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
            prevShardedSize, factorSizes[factorIndex], mesh);
@@ -133,7 +161,7 @@
   // checking for conflicts w.r.t. the updated state of this tensor.
   truncateAxesByRemovingConflicts(
       newAxes,
-      [&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
+      [&](AxisRefAttr axisRef, int64_t) {
        return compatiblePrefixNoConflictsAcrossFactors(
            axisRef, factorIndexToSharding, factorIndex);
      },
@@ -182,71 +210,54 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
                     factorToSourceTensor[j].index, j);
   });
 
-  for (const auto& [tensorIndex, tensorFactorShardings] :
-       llvm::enumerate(projection.getResults())) {
-    const FactorIndexToSharding& factorIndexToSharding =
-        tensorFactorShardings.factorIndexToSharding;
-
-    // Propagate the axes got in Step 1, resolving conflicts between factors by
-    // following the order of preference in `sortedFactorIndices`.
-    bool tensorUpdated = false;
-    for (int64_t factorIndex : sortedFactorIndices) {
-      SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding(
-          factorIndex, tensorFactorShardings, factorIndexToSharding,
-          axesPerFactor, mesh, conservativePropagation, factorSizes);
-      if (newAxes.empty()) {
-        continue;
+  // Propagate the axes got in Step 1, resolving conflicts between factors by
+  // following the order of preference in `sortedFactorIndices`.
+  for (int64_t factorIndex : sortedFactorIndices) {
+    PropagationDirection direction = directionAlongFactor(factorIndex);
+    // 1. Propagate to results.
+    //
+    // We don't propagate sideways between results in backwards propagation,
+    // so the sharding of the results along this factor shouldn't change.
+    if (direction != PropagationDirection::BACKWARD) {
+      for (const auto& [tensorIndex, tensorFactorShardings] :
+           llvm::enumerate(projection.getResults())) {
+        SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding(
+            factorIndex, tensorFactorShardings,
+            tensorFactorShardings.factorIndexToSharding, axesPerFactor, mesh,
+            conservativePropagation, factorSizes);
+        if (newAxes.empty()) {
+          continue;
+        }
+        if (expandTensorSharding(projection,
+                                 tensorIndex + projection.getNumOperands(),
+                                 factorIndex, newAxes)) {
+          result.updateResults.set(tensorIndex);
+        }
       }
-      tensorUpdated |= expandTensorSharding(
-          projection, tensorIndex + projection.getNumOperands(), factorIndex,
-          newAxes);
     }
-    result.updateResults[tensorIndex] = tensorUpdated;
-  }
-
-  for (const auto& [tensorIndex, tensorFactorShardings] :
-       llvm::enumerate(projection.getOperands())) {
-    const FactorIndexToSharding& factorIndexToSharding =
-        tensorFactorShardings.factorIndexToSharding;
-
-    // Propagate the axes got in Step 1, resolving conflicts between factors by
-    // following the order of preference in `sortedFactorIndices`.
-    bool tensorUpdated = false;
-    for (int64_t factorIndex : sortedFactorIndices) {
-      SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding(
-          factorIndex, tensorFactorShardings, factorIndexToSharding,
-          axesPerFactor, mesh, conservativePropagation, factorSizes);
-      if (newAxes.empty()) {
-        continue;
-      }
 
-      // Only propagate sideways through operands the factors that are also
-      // used in at least one result We want to avoid the following situation
-      // which can happen when a `sharding_constraint` is added onto the operand
-      // during Shardy import:
-      // ```
-      // %arg0: [{"a", ?}]
-      // %arg1: [{?}]
-      // %0 = add %arg0, %arg1 : [{}]
-      // ```
-      // We don't want to do an all-gather on both %arg0 and %arg1 due to "a"
-      // propagating sideways. Instead with the code below, since "a" can't
-      // propagate to `%0`, we will only do an all-gather on %arg0.
-      //
-      // TODO(b/396642774): Long term we should undo this and allow sideways
-      // propagation, but have our explicit reshard pass make sure the result is
-      // all-gathered instead of both operands.
-      if (op && isElementwise(op)) {
-        for (const TensorFactorShardings& result : projection.getResults()) {
-          if (isStrictPrefixOfFactorSharding(result, factorIndex, newAxes)) {
-            newAxes = result.factorIndexToSharding.at(factorIndex).axisRefs;
-          }
+    // 2. Propagate to operands.
+    //
+    // We don't propagate sideways between operands in forward propagation,
+    // so the sharding of the operands along this factor shouldn't change.
+    if (direction != PropagationDirection::FORWARD) {
+      for (const auto& [tensorIndex, tensorFactorShardings] :
+           llvm::enumerate(projection.getOperands())) {
+        SmallVector<AxisRefAttr> newAxes = getPropagatedFactorSharding(
+            factorIndex, tensorFactorShardings,
+            tensorFactorShardings.factorIndexToSharding, axesPerFactor, mesh,
+            conservativePropagation, factorSizes);
+        if (newAxes.empty()) {
+          continue;
+        }
+        cancelSidewaysPropagationForElementwise(projection, factorIndex,
+                                                newAxes, op);
+        if (expandTensorSharding(projection, tensorIndex, factorIndex,
+                                 newAxes)) {
+          result.updateOperands.set(tensorIndex);
        }
      }
-      tensorUpdated |=
-          expandTensorSharding(projection, tensorIndex, factorIndex, newAxes);
    }
-    result.updateOperands[tensorIndex] = tensorUpdated;
  }
  return result;
}
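
To see what cancelSidewaysPropagationForElementwise buys us, here is a standalone sketch of the same clamping logic over plain strings; the types and helper names are simplified stand-ins, not the Shardy API. If some result holds only a strict prefix of the axes about to reach an operand, the axes are clamped to that prefix, so an axis that cannot reach the results never travels sideways to a sibling operand.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// A factor's sharding, reduced to a list of axis names.
using Axes = std::vector<std::string>;

// True if `prefix` is a strict prefix of `axes`.
bool isStrictPrefix(const Axes& prefix, const Axes& axes) {
  return prefix.size() < axes.size() &&
         std::equal(prefix.begin(), prefix.end(), axes.begin());
}

// Clamp `newAxes` to the sharding of any result that holds only a strict
// prefix of it, mirroring the helper added in the diff above.
void clampToResults(const std::vector<Axes>& resultAxes, Axes& newAxes) {
  for (const Axes& result : resultAxes) {
    if (isStrictPrefix(result, newAxes)) newAxes = result;
  }
}

int main() {
  // The situation from the comment above: %arg0 is sharded on "a", while
  // %arg1 and the result %0 are unsharded along the factor.
  Axes newAxes = {"a"};              // axes arriving from %arg0
  std::vector<Axes> results = {{}};  // %0 holds no axes for this factor
  clampToResults(results, newAxes);
  assert(newAxes.empty());  // "a" is dropped; %arg1 stays unsharded
  return 0;
}

With the clamp in place, only %arg0 needs an all-gather; without it, "a" would leak sideways to %arg1 and force an all-gather on both operands.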

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc (+97)

@@ -397,6 +397,103 @@ TEST_F(AggressiveFactorPropagationTest, PropagateAlongSpecificFactor) {
   propagateAlongFactor(propagateAnything(), propagateAlongFactor0Expected);
 }
 
+// NOTE: This test is the same as the one in basic_factor_propagation_test.cc,
+// and verifies that we get the expected behavior in both strategies.
+TEST_F(AggressiveFactorPropagationTest,
+       DifferentDirectionsForDifferentFactors) {
+  ShardingProjection projection(
+      /*operands=*/
+      {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
+                                  {1, {.axisRefs = {createAxis("b")}}},
+                                  {2, {.axisRefs = {createAxis("c")}}},
+                                  {3, {.axisRefs = {createAxis("d")}}},
+                                  {4, {.axisRefs = {}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {}}},
+                                  {7, {.axisRefs = {}}}}},
+       {.factorIndexToSharding = {{0, {.axisRefs = {}}},
+                                  {1, {.axisRefs = {}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {}}},
+                                  {7, {.axisRefs = {}}}}}},
+      /*results=*/
+      {{.factorIndexToSharding = {{0, {.axisRefs = {}}},
+                                  {1, {.axisRefs = {}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {createAxis("e")}}},
+                                  {5, {.axisRefs = {createAxis("f")}}},
+                                  {6, {.axisRefs = {}}},
+                                  {7, {.axisRefs = {createAxis("h")}}}}},
+       {.factorIndexToSharding = {{0, {.axisRefs = {}}},
+                                  {1, {.axisRefs = {}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {7, {.axisRefs = {}}}}}});
+
+  PropagationDirectionAlongFactor directionAlongFactor =
+      [](int64_t factorIndex) {
+        if (factorIndex == 0 || factorIndex == 4) {
+          return PropagationDirection::BOTH;
+        }
+        if (factorIndex == 1 || factorIndex == 5) {
+          return PropagationDirection::FORWARD;
+        }
+        if (factorIndex == 2 || factorIndex == 6) {
+          return PropagationDirection::BACKWARD;
+        }
+        return PropagationDirection::NONE;
+      };
+
+  ShardingProjection projectionExpected(
+      /*operands=*/
+      {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
+                                  {1, {.axisRefs = {createAxis("b")}}},
+                                  {2, {.axisRefs = {createAxis("c")}}},
+                                  {3, {.axisRefs = {createAxis("d")}}},
+                                  {4, {.axisRefs = {createAxis("e")}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {7, {.axisRefs = {}}}}},
+       {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
+                                  {1, {.axisRefs = {}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {createAxis("e")}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {7, {.axisRefs = {}}}}}},
+      /*results=*/
+      {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
+                                  {1, {.axisRefs = {createAxis("b")}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {createAxis("e")}}},
+                                  {5, {.axisRefs = {createAxis("f")}}},
+                                  {6, {.axisRefs = {}}},
+                                  {7, {.axisRefs = {createAxis("h")}}}}},
+       {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
+                                  {1, {.axisRefs = {createAxis("b")}}},
+                                  {2, {.axisRefs = {}}},
+                                  {3, {.axisRefs = {}}},
+                                  {4, {.axisRefs = {createAxis("e")}}},
+                                  {5, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {7, {.axisRefs = {}}}}}});
+
+  auto [updateOperands, updateResults] =
+      propagateFactorShardings(projection, 8, directionAlongFactor);
+  EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(0, 1));
+  EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 1));
+  EXPECT_EQ(projection, projectionExpected);
+}
+
 // NOLINTEND(clang-diagnostic-pre-c++20-compat-pedantic)
 
 }  // namespace

shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc (+4 -3)

@@ -398,16 +398,17 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
 
   // We propagate each factor separately.
   for (auto [factorIndex, factorSize] : llvm::enumerate(factorSizes)) {
+    PropagationDirection direction = directionAlongFactor(factorIndex);
     // For each factor, find the compatible major sharding axes that can shard
     // that factor for all tensors, those are the axes we will propagate to
     // tensors that aren't already sharded.
     SmallVector<AxisRefAttr> axesToPropagate = getCompatibleMajorShardingAxes(
-        projection, factorIndex, directionAlongFactor(factorIndex), factorSize,
-        mesh, op, conservativePropagation);
+        projection, factorIndex, direction, factorSize, mesh, op,
+        conservativePropagation);
 
     // Update all shardings along this factor if possible.
     auto [updateOperandForFactor, updateResultForFactor] =
-        projection.expandSharding(factorIndex, axesToPropagate);
+        projection.expandSharding(factorIndex, axesToPropagate, direction);
 
     result.updateOperands |= updateOperandForFactor;
     result.updateResults |= updateResultForFactor;

shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc (+4 -4)

@@ -513,15 +513,15 @@ TEST_F(BasicFactorPropagationTest, DifferentDirectionsForDifferentFactors) {
                                   {3, {.axisRefs = {}}},
                                   {4, {.axisRefs = {createAxis("e")}}},
                                   {5, {.axisRefs = {createAxis("f")}}},
-                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {6, {.axisRefs = {}}},
                                   {7, {.axisRefs = {createAxis("h")}}}}},
        {.factorIndexToSharding = {{0, {.axisRefs = {}}},
                                   {1, {.axisRefs = {}}},
                                   {2, {.axisRefs = {}}},
                                   {3, {.axisRefs = {}}},
                                   {4, {.axisRefs = {}}},
                                   {5, {.axisRefs = {}}},
-                                  {6, {.axisRefs = {}}},
+                                  {6, {.axisRefs = {createAxis("g")}}},
                                   {7, {.axisRefs = {}}}}}});
 
   PropagationDirectionAlongFactor directionAlongFactor =
@@ -549,7 +549,7 @@ TEST_F(BasicFactorPropagationTest, DifferentDirectionsForDifferentFactors) {
                                   {6, {.axisRefs = {createAxis("g")}}},
                                   {7, {.axisRefs = {}}}}},
        {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
-                                  {1, {.axisRefs = {createAxis("b")}}},
+                                  {1, {.axisRefs = {}}},
                                   {2, {.axisRefs = {}}},
                                   {3, {.axisRefs = {}}},
                                   {4, {.axisRefs = {createAxis("e")}}},
@@ -563,7 +563,7 @@ TEST_F(BasicFactorPropagationTest, DifferentDirectionsForDifferentFactors) {
                                   {3, {.axisRefs = {}}},
                                   {4, {.axisRefs = {createAxis("e")}}},
                                   {5, {.axisRefs = {createAxis("f")}}},
-                                  {6, {.axisRefs = {createAxis("g")}}},
+                                  {6, {.axisRefs = {}}},
                                   {7, {.axisRefs = {createAxis("h")}}}}},
        {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
                                   {1, {.axisRefs = {createAxis("b")}}},

shardy/dialect/sdy/transforms/propagation/sharding_projection.cc (+16 -5)

@@ -190,13 +190,24 @@ TensorShardingAttr TensorFactorShardings::createTensorShardingAttr(
 }
 
 UpdateTensorShardings ShardingProjection::expandSharding(
-    int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
+    int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes,
+    PropagationDirection direction) {
   UpdateTensorShardings result(getNumOperands(), getNumResults());
-  for (auto [i, tensor] : llvm::enumerate(operands)) {
-    result.updateOperands[i] = tensor.expandShardingAxes(factorIndex, newAxes);
+  if (direction == PropagationDirection::NONE) {
+    return result;
   }
-  for (auto [i, tensor] : llvm::enumerate(results)) {
-    result.updateResults[i] = tensor.expandShardingAxes(factorIndex, newAxes);
+  // We don't propagate sideways between operands in forward propagation.
+  if (direction != PropagationDirection::FORWARD) {
+    for (auto [i, tensor] : llvm::enumerate(operands)) {
+      result.updateOperands[i] =
+          tensor.expandShardingAxes(factorIndex, newAxes);
+    }
+  }
+  // We don't propagate sideways between results in backwards propagation.
+  if (direction != PropagationDirection::BACKWARD) {
+    for (auto [i, tensor] : llvm::enumerate(results)) {
+      result.updateResults[i] = tensor.expandShardingAxes(factorIndex, newAxes);
+    }
   }
   return result;
 }

shardy/dialect/sdy/transforms/propagation/sharding_projection.h (+8 -1)

@@ -208,8 +208,15 @@ class ShardingProjection {
   // Expands the shardings of all tensors that are associated with
   // `factorIndex` to be `newAxes` for that factor. Returns two BitVectors
   // indicating whether the operands and results have been expanded.
+  //
+  // If direction is:
+  // - BOTH, both operands and results can be updated.
+  // - FORWARD, only results can be updated.
+  // - BACKWARD, only operands can be updated.
+  // - NONE, no tensors are updated.
   UpdateTensorShardings expandSharding(int64_t factorIndex,
-                                       ArrayRef<AxisRefAttr> newAxes);
+                                       ArrayRef<AxisRefAttr> newAxes,
+                                       PropagationDirection direction);
 
   // Updates the shardings of all tensors that are associated with
   // `factorIndex` to be `newAxes` and `newOverflowAxes` for that factor. Keep
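
Putting the new overload's contract together, the direction gating can be demonstrated end to end with a toy projection. Everything below is a simplified stand-in (plain ints and strings instead of UpdateTensorShardings and AxisRefAttr) that mirrors the control flow documented above, not the real Shardy API.

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };

// Toy tensor: its sharding along a single factor, as axis names.
struct Tensor {
  std::vector<std::string> axes;
  // Grows the sharding if `newAxes` extends it (stand-in for
  // expandShardingAxes); returns true if the sharding changed.
  bool expand(const std::vector<std::string>& newAxes) {
    if (newAxes.size() <= axes.size()) return false;
    axes = newAxes;
    return true;
  }
};

struct Projection {
  std::vector<Tensor> operands, results;

  // NONE updates nothing, FORWARD skips operands, BACKWARD skips results.
  std::pair<int, int> expandSharding(const std::vector<std::string>& newAxes,
                                     PropagationDirection direction) {
    int operandUpdates = 0, resultUpdates = 0;
    if (direction == PropagationDirection::NONE) return {0, 0};
    if (direction != PropagationDirection::FORWARD)
      for (Tensor& t : operands) operandUpdates += t.expand(newAxes);
    if (direction != PropagationDirection::BACKWARD)
      for (Tensor& t : results) resultUpdates += t.expand(newAxes);
    return {operandUpdates, resultUpdates};
  }
};

int main() {
  Projection p{/*operands=*/{Tensor{{"a"}}, Tensor{}}, /*results=*/{Tensor{}}};
  // FORWARD: the result picks up "a", but the unsharded second operand is
  // left alone; no sideways propagation between operands.
  auto [ops, res] = p.expandSharding({"a"}, PropagationDirection::FORWARD);
  std::printf("operands updated: %d, results updated: %d\n", ops, res);
}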
