@@ -55,23 +55,6 @@ bool isTranspose(stablehlo::Transpose transpose) {
   llvm_unreachable("unknown stablehlo::Transpose");
 }
 
-// If `addFactorForMismatchedSize` is true, we add a factor for all dimensions
-// with the corresponding size in `inType`, otherwise we only add a factor for
-// dimensions with the same input and output size, letting the builder add size
-// 1 factors for other dimensions.
-OpShardingRuleAttr createMismatchedDimSizeShardingRule(
-    Operation* op, RankedTensorType inType, RankedTensorType outType,
-    bool addFactorForMismatchedSize) {
-  return OpShardingRuleBuilder(op)
-      .addPointwiseIf(inType.getShape(),
-                      [&](int64_t dim) {
-                        return addFactorForMismatchedSize ||
-                               inType.getDimSize(dim) ==
-                                   outType.getDimSize(dim);
-                      })
-      .build();
-}
-
 // Returns a vector with `numInputs` copies of `inputDim`, followed by a single
 // `indicesDim`, then `numInputs` copies of `updateDim`, which matches the order
 // and quantity of scatter operands.
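Note on the refactor: the removed helper's behavior hinges on one per-dimension predicate, which the new `addPointwiseIfDimSizesMatch` calls below express through their `alwaysAddFactor` argument (surfacing as `/*alwaysAddFactor=*/!conservativePropagation` in the PadOp, ReduceWindowOp, SelectAndScatterOp, and SliceOp cases further down). Here is a minimal standalone sketch of that predicate, assuming the builder method evaluates the same condition the removed lambda did; `shouldAddFactor` is a hypothetical name for illustration, not part of the Shardy API:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Models the per-dimension predicate from the removed helper: add a factor
// either unconditionally, or only where input and output sizes agree.
bool shouldAddFactor(const std::vector<int64_t>& inShape,
                     const std::vector<int64_t>& outShape, int64_t dim,
                     bool alwaysAddFactor) {
  return alwaysAddFactor || inShape[dim] == outShape[dim];
}

int main() {
  // Pad-like shapes (assumed for the example): dimension 1 grows from 6 to 8.
  std::vector<int64_t> in = {4, 6}, out = {4, 8};
  for (int64_t dim = 0; dim < 2; ++dim) {
    std::printf("dim %lld: conservative=%d aggressive=%d\n",
                static_cast<long long>(dim),
                shouldAddFactor(in, out, dim, /*alwaysAddFactor=*/false),
                shouldAddFactor(in, out, dim, /*alwaysAddFactor=*/true));
  }
  return 0;
}
```

Run as-is, it reports that only dimension 0 gets a factor in conservative mode, while both dimensions do when the flag is set.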
@@ -450,11 +433,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
         //
         // Operands: [operand, iota, init_val (scalar), init_arg (scalar)]
         // Results: [values, indices]
-        return createMismatchedDimSizeShardingRule(
-            customCall,
-            cast<RankedTensorType>(customCall.getOperand(0).getType()),
-            cast<RankedTensorType>(customCall.getResult(0).getType()),
-            /*addFactorForMismatchedSize=*/false);
+        ArrayRef<int64_t> inputShape =
+            getTensorShape(customCall.getOperand(0));
+        ArrayRef<int64_t> resultShape =
+            getTensorShape(customCall.getResult(0));
+        int64_t numInputs = 2, numResults = 2;
+        SmallVector<int64_t> operandDims(customCall->getNumOperands(),
+                                         kNullDim);
+        SmallVector<int64_t> resultDims(customCall->getNumResults(),
+                                        kNullDim);
+        return OpShardingRuleBuilder(customCall)
+            .addPointwiseIfDimSizesMatch(
+                inputShape, resultShape,
+                /*alwaysAddFactor=*/false,
+                /*onMismatchFn=*/
+                [&](int64_t dim, OpShardingRuleBuilder& builder) {
+                  std::fill_n(operandDims.begin(), numInputs, dim);
+                  resultDims.assign(numResults, kNullDim);
+                  builder.addFactor(operandDims, resultDims, inputShape[dim]);
+                  resultDims.assign(numResults, dim);
+                  std::fill_n(operandDims.begin(), numInputs, kNullDim);
+                  builder.addFactor(operandDims, resultDims,
+                                    resultShape[dim]);
+                })
+            .build();
       }
       // TODO(b/327191011): output unregistered op stats instead.
       unreachableFormatv(
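To make the `onMismatchFn` above concrete: for a top-k custom call, every batch dimension becomes an ordinary pointwise factor, while the mismatched top-k dimension is split into an input-only factor and a result-only factor, so each side keeps an independently shardable dimension. Below is a hedged standalone walk-through with made-up shapes (operand [8, 16], k = 4); the `Factor` struct is a toy model of the builder's bookkeeping, and the matched-dimension branch assumes the builder leaves the two scalar operands out of the factor:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kNullDim = -1;

// One sharding factor: a dimension index per operand/result (kNullDim means
// the tensor does not participate) plus the factor size.
struct Factor {
  std::vector<int64_t> operandDims, resultDims;
  int64_t size;
};

int main() {
  // TopK with operand shape [8, 16] and k = 4, so results have shape [8, 4].
  // Operands: [operand, iota, init_val (scalar), init_arg (scalar)].
  std::vector<int64_t> inputShape = {8, 16}, resultShape = {8, 4};
  const int64_t numInputs = 2, numOperands = 4, numResults = 2;
  std::vector<Factor> factors;
  for (int64_t dim = 0; dim < 2; ++dim) {
    std::vector<int64_t> operandDims(numOperands, kNullDim);
    std::vector<int64_t> resultDims(numResults, kNullDim);
    if (inputShape[dim] == resultShape[dim]) {
      // Batch dimension: one pointwise factor tying inputs and results.
      std::fill_n(operandDims.begin(), numInputs, dim);
      resultDims.assign(numResults, dim);
      factors.push_back({operandDims, resultDims, inputShape[dim]});
    } else {
      // Top-k dimension: an input-only factor of size 16 ...
      std::fill_n(operandDims.begin(), numInputs, dim);
      factors.push_back({operandDims, resultDims, inputShape[dim]});
      // ... and a result-only factor of size 4.
      operandDims.assign(numOperands, kNullDim);
      resultDims.assign(numResults, dim);
      factors.push_back({operandDims, resultDims, resultShape[dim]});
    }
  }
  for (const Factor& f : factors)
    std::printf("factor of size %lld\n", static_cast<long long>(f.size));
  return 0;
}
```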
@@ -540,30 +542,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
           })
       .Case<stablehlo::DynamicSliceOp>(
           [](stablehlo::DynamicSliceOp dynamicSlice) {
-            return createMismatchedDimSizeShardingRule(
-                dynamicSlice, dynamicSlice.getOperand().getType(),
-                dynamicSlice.getType(),
-                /*addFactorForMismatchedSize=*/false);
+            return OpShardingRuleBuilder(dynamicSlice)
+                .addPointwiseIfDimSizesMatch(
+                    getTensorShape(dynamicSlice.getOperand()),
+                    getTensorShape(dynamicSlice.getResult()))
+                .build();
           })
       .Case<stablehlo::DynamicUpdateSliceOp>(
           [](stablehlo::DynamicUpdateSliceOp dynamicUpdateSlice) {
-            OpShardingRuleBuilder builder(dynamicUpdateSlice);
             ArrayRef<int64_t> operandShape =
-                dynamicUpdateSlice.getOperand().getType().getShape();
+                getTensorShape(dynamicUpdateSlice.getOperand());
             ArrayRef<int64_t> updateShape =
-                dynamicUpdateSlice.getUpdate().getType().getShape();
-
+                getTensorShape(dynamicUpdateSlice.getUpdate());
             SmallVector<int64_t> operandDims(
                 dynamicUpdateSlice->getNumOperands(), kNullDim);
-            for (auto [dim, dimSizes] :
-                 llvm::enumerate(llvm::zip_equal(operandShape, updateShape))) {
-              auto [operandDimSize, updateDimSize] = dimSizes;
-              operandDims[0] = dim;
-              operandDims[1] = operandDimSize == updateDimSize ? dim : kNullDim;
-              builder.addFactor(operandDims, dim, operandDimSize);
-            }
-
-            return builder.build();
+            return OpShardingRuleBuilder(dynamicUpdateSlice)
+                .addPointwiseIfDimSizesMatch(
+                    operandShape, updateShape,
+                    /*alwaysAddFactor=*/false,
+                    /*onMismatchFn=*/
+                    [&](int64_t dim, OpShardingRuleBuilder& builder) {
+                      operandDims[0] = dim;
+                      builder.addFactor(operandDims, dim, operandShape[dim]);
+                    })
+                .build();
           })
       .Case<stablehlo::FftOp>([](stablehlo::FftOp fft) {
         ArrayRef<int64_t> inShape = getTensorShape(fft.getOperand());
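The `DynamicUpdateSliceOp` case keeps operand and result pointwise on every dimension; the update participates only where its size matches the operand's, otherwise the mismatch callback above ties just operand and result while the update entry stays `kNullDim`. A small self-contained illustration of that rule with assumed shapes (operand [8, 4], update [2, 4]; the scalar start indices are omitted):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kNullDim = -1;

int main() {
  std::vector<int64_t> operandShape = {8, 4}, updateShape = {2, 4};
  for (int64_t dim = 0; dim < 2; ++dim) {
    // On a size mismatch the update is left out of the factor (kNullDim),
    // i.e. it stays unsharded along this dimension while operand and result
    // still share a factor sized by the operand.
    int64_t updateDim =
        operandShape[dim] == updateShape[dim] ? dim : kNullDim;
    std::printf("dim %lld: operand=%lld update=%lld result=%lld size=%lld\n",
                static_cast<long long>(dim), static_cast<long long>(dim),
                static_cast<long long>(updateDim), static_cast<long long>(dim),
                static_cast<long long>(operandShape[dim]));
  }
  return 0;
}
```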
@@ -602,9 +604,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
       .Case<stablehlo::PadOp>([conservativePropagation](stablehlo::PadOp pad) {
         // If `conservativePropagation` is false, we propagate through padded
         // dimensions, even though that would require communication.
-        return createMismatchedDimSizeShardingRule(
-            pad, pad.getOperand().getType(), pad.getType(),
-            /*addFactorForMismatchedSize=*/!conservativePropagation);
+        return OpShardingRuleBuilder(pad)
+            .addPointwiseIfDimSizesMatch(
+                getTensorShape(pad.getOperand()),
+                getTensorShape(pad.getResult()),
+                /*alwaysAddFactor=*/!conservativePropagation)
+            .build();
       })
       .Case<stablehlo::ReduceOp>([](stablehlo::ReduceOp reduce) {
         OpShardingRuleBuilder builder(reduce);
@@ -657,12 +662,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
             // In conservative mode, we only add a factor if the input and
             // output dimension sizes are equal.
            // TODO(tomnatan): should the reduced factor be compound?
-            return createMismatchedDimSizeShardingRule(
-                reduceWindow,
-                cast<RankedTensorType>(reduceWindow.getResult(0).getType()),
-                cast<RankedTensorType>(
-                    reduceWindow.getInputs().front().getType()),
-                /*addFactorForMismatchedSize=*/!conservativePropagation);
+            return OpShardingRuleBuilder(reduceWindow)
+                .addPointwiseIfDimSizesMatch(
+                    getTensorShape(reduceWindow.getInputs().front()),
+                    getTensorShape(reduceWindow.getResult(0)),
+                    /*alwaysAddFactor=*/!conservativePropagation)
+                .build();
           })
       .Case<stablehlo::ReshapeOp>([](stablehlo::ReshapeOp reshape) {
         RankedTensorType inType = reshape.getOperand().getType();
@@ -790,10 +795,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
             // In conservative mode, we only add a factor if the input and
             // source dimension sizes are equal.
             // TODO(tomnatan): should the reduced factor be compound?
-            return createMismatchedDimSizeShardingRule(
-                selectAndScatter, selectAndScatter.getSource().getType(),
-                selectAndScatter.getOperand().getType(),
-                /*addFactorForMismatchedSize=*/!conservativePropagation);
+            return OpShardingRuleBuilder(selectAndScatter)
+                .addPointwiseIfDimSizesMatch(
+                    getTensorShape(selectAndScatter.getOperand()),
+                    getTensorShape(selectAndScatter.getSource()),
+                    /*alwaysAddFactor=*/!conservativePropagation)
+                .build();
           })
       .Case<stablehlo::SelectOp>([](stablehlo::SelectOp select) {
         // Case 1: `pred` is a scalar in which case it is broadcasted and must
@@ -815,9 +822,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
         // `conservativePropagation`, and the reason is that for `SliceOp`
         // the start indices are static, so we know how to shift the data
         // to keep the sliced dimension sharded.
-        return createMismatchedDimSizeShardingRule(
-            slice, slice.getOperand().getType(), slice.getType(),
-            /*addFactorForMismatchedSize=*/!conservativePropagation);
+        return OpShardingRuleBuilder(slice)
+            .addPointwiseIfDimSizesMatch(
+                getTensorShape(slice.getOperand()),
+                getTensorShape(slice.getResult()),
+                /*alwaysAddFactor=*/!conservativePropagation)
+            .build();
       })
       .Case<stablehlo::TransposeOp>([](stablehlo::TransposeOp transpose) {
         OpShardingRuleBuilder builder(transpose);
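Why `SliceOp` can keep the sliced dimension sharded, per the comment in the hunk above: the start indices are compile-time constants, so the input shard that feeds each output shard is statically known and the data can be shifted accordingly. A toy calculation under assumed numbers (a size-16 dimension sliced to size 8 starting at offset 4, split across 4 shards); this models the intuition only, not Shardy's actual partitioning algorithm:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t inputSize = 16, start = 4, outputSize = 8, numShards = 4;
  const int64_t inPerShard = inputSize / numShards;    // 4 input elements/shard
  const int64_t outPerShard = outputSize / numShards;  // 2 output elements/shard
  for (int64_t s = 0; s < numShards; ++s) {
    // Input index of the first output element held by shard s: known
    // statically because `start` is a constant.
    const int64_t firstIn = start + s * outPerShard;
    std::printf("output shard %lld starts at input index %lld (input shard %lld)\n",
                static_cast<long long>(s), static_cast<long long>(firstIn),
                static_cast<long long>(firstIn / inPerShard));
  }
  return 0;
}
```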