@@ -288,7 +288,7 @@ std::pair<SmallVector<AxisRefAttr>, bool> getCompatibleMajorAxesInternal(
288
288
289
289
SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes (
290
290
const ShardingProjection& projection, int64_t factorIndex,
291
- PropagationDirection direction, Operation* op ) const {
291
+ PropagationDirection direction) const {
292
292
if (direction == PropagationDirection::NONE) {
293
293
return {};
294
294
}
@@ -370,10 +370,10 @@ std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
370
370
SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes (
371
371
const ShardingProjection& projection, int64_t factorIndex,
372
372
PropagationDirection direction, int64_t factorSize, MeshAttr mesh,
373
- Operation* op, bool conservativePropagation) const {
373
+ bool conservativePropagation) const {
374
374
// Finds the compatible major axes ignoring conflicts.
375
375
SmallVector<AxisRefAttr> resultAxes =
376
- getCompatibleMajorAxes (projection, factorIndex, direction, op );
376
+ getCompatibleMajorAxes (projection, factorIndex, direction);
377
377
378
378
// Removes the major-most axis that isn't compatible w.r.t. other factors or
379
379
// the replicated axes, and all axes that are minor to it.
@@ -391,7 +391,7 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes(
391
391
UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings (
392
392
ShardingProjection& projection,
393
393
PropagationDirectionAlongFactor directionAlongFactor,
394
- ArrayRef<int64_t > factorSizes, MeshAttr mesh, Operation* op ,
394
+ ArrayRef<int64_t > factorSizes, MeshAttr mesh, Operation*,
395
395
bool conservativePropagation) const {
396
396
UpdateTensorShardings result (projection.getNumOperands (),
397
397
projection.getNumResults ());
@@ -403,7 +403,7 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
403
403
// tensors that aren't already sharded.
404
404
SmallVector<AxisRefAttr> axesToPropagate = getCompatibleMajorShardingAxes (
405
405
projection, factorIndex, directionAlongFactor (factorIndex), factorSize,
406
- mesh, op, conservativePropagation);
406
+ mesh, conservativePropagation);
407
407
408
408
// Update all shardings along this factor if possible.
409
409
auto [updateOperandForFactor, updateResultForFactor] =
0 commit comments