Skip to content

Commit 3e5e377

Browse files
ZixuanJiangcopybara-github
authored andcommitted
#sdy Remove unused op parameter.
Pure refactoring. PiperOrigin-RevId: 737780450
1 parent b6defad commit 3e5e377

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
156156
bool allElementsAreEmpty = true;
157157
for (int64_t i = 0; i < factorSizes.size(); ++i) {
158158
SmallVector<AxisRefAttr>& axes = axesPerFactor.emplace_back(
159-
getCompatibleMajorAxes(projection, i, directionAlongFactor(i), op));
159+
getCompatibleMajorAxes(projection, i, directionAlongFactor(i)));
160160
if (!axes.empty()) {
161161
allElementsAreEmpty = false;
162162
}

shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ std::pair<SmallVector<AxisRefAttr>, bool> getCompatibleMajorAxesInternal(
288288

289289
SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
290290
const ShardingProjection& projection, int64_t factorIndex,
291-
PropagationDirection direction, Operation* op) const {
291+
PropagationDirection direction) const {
292292
if (direction == PropagationDirection::NONE) {
293293
return {};
294294
}
@@ -370,10 +370,10 @@ std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
370370
SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes(
371371
const ShardingProjection& projection, int64_t factorIndex,
372372
PropagationDirection direction, int64_t factorSize, MeshAttr mesh,
373-
Operation* op, bool conservativePropagation) const {
373+
bool conservativePropagation) const {
374374
// Finds the compatible major axes ignoring conflicts.
375375
SmallVector<AxisRefAttr> resultAxes =
376-
getCompatibleMajorAxes(projection, factorIndex, direction, op);
376+
getCompatibleMajorAxes(projection, factorIndex, direction);
377377

378378
// Removes the major-most axis that isn't compatible w.r.t. other factors or
379379
// the replicated axes, and all axes that are minor to it.
@@ -391,7 +391,7 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes(
391391
UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
392392
ShardingProjection& projection,
393393
PropagationDirectionAlongFactor directionAlongFactor,
394-
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
394+
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation*,
395395
bool conservativePropagation) const {
396396
UpdateTensorShardings result(projection.getNumOperands(),
397397
projection.getNumResults());
@@ -403,7 +403,7 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
403403
// tensors that aren't already sharded.
404404
SmallVector<AxisRefAttr> axesToPropagate = getCompatibleMajorShardingAxes(
405405
projection, factorIndex, directionAlongFactor(factorIndex), factorSize,
406-
mesh, op, conservativePropagation);
406+
mesh, conservativePropagation);
407407

408408
// Update all shardings along this factor if possible.
409409
auto [updateOperandForFactor, updateResultForFactor] =

shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class BasicFactorPropagation : public FactorPropagation {
4444
UpdateTensorShardings propagateFactorShardings(
4545
ShardingProjection& projection,
4646
PropagationDirectionAlongFactor directionAlongFactor,
47-
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
47+
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation*,
4848
bool conservativePropagation) const override;
4949

5050
protected:
@@ -88,7 +88,7 @@ class BasicFactorPropagation : public FactorPropagation {
8888
SmallVector<AxisRefAttr> getCompatibleMajorShardingAxes(
8989
const ShardingProjection& projection, int64_t factorIndex,
9090
PropagationDirection direction, int64_t factorSize, MeshAttr mesh,
91-
Operation* op, bool conservativePropagation) const;
91+
bool conservativePropagation) const;
9292

9393
// Finds the longest prefix of axes that shard the given factor, such that all
9494
// tensors either:
@@ -99,7 +99,7 @@ class BasicFactorPropagation : public FactorPropagation {
9999
// This method does not resolve conflicts across factors or replicated axes.
100100
SmallVector<AxisRefAttr> getCompatibleMajorAxes(
101101
const ShardingProjection& projection, int64_t factorIndex,
102-
PropagationDirection direction, Operation* op) const;
102+
PropagationDirection direction) const;
103103

104104
// Returns the largest prefix of `axisRef`, which does not overlap with
105105
// sharding axes and overflow axes for all other factors.

0 commit comments

Comments
 (0)