Skip to content

Commit 189519a

Browse files
tomnatan30 and copybara-github
authored and committed Jul 19, 2024 ·
#sdy fix sharding rule of @PartialReduce/ApproxTopK custom call.
PiperOrigin-RevId: 653581265
1 parent 446f1b6 commit 189519a

File tree

4 files changed

+107
-73
lines changed

4 files changed

+107
-73
lines changed
 

‎shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc

+16
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,22 @@ OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIf(
184184
return *this;
185185
}
186186

187+
OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIfDimSizesMatch(
188+
ArrayRef<int64_t> inShape, ArrayRef<int64_t> outShape, bool alwaysAddFactor,
189+
std::function<void(int64_t dim, OpShardingRuleBuilder& builder)>
190+
onMismatchFn) {
191+
for (auto [dim, dimSizes] :
192+
llvm::enumerate(llvm::zip_equal(inShape, outShape))) {
193+
auto [inDimSize, outDimSize] = dimSizes;
194+
if (alwaysAddFactor || inDimSize == outDimSize) {
195+
addFactor(dim, inDimSize);
196+
} else {
197+
onMismatchFn(dim, *this);
198+
}
199+
}
200+
return *this;
201+
}
202+
187203
OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type,
188204
size_t numOperands,
189205
size_t numResults) {

‎shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h

+11-3
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,20 @@ class OpShardingRuleBuilder {
8888

8989
// Adds a pointwise factor for all dimensions that satisfy `pred` of all
9090
// operands/results that have rank at least 1.
91-
//
92-
// Adds a factor of size 1 to all other dimensions, which would block any
93-
// propagation along these dimensions.
9491
OpShardingRuleBuilder& addPointwiseIf(ArrayRef<int64_t> shape,
9592
std::function<bool(int64_t)> pred);
9693

94+
// Adds a pointwise factor for all dimensions whose input and output sizes
95+
// match, across all operands/results that have rank at least 1.
96+
//
97+
// If `alwaysAddFactor` is true, we add a factor for all dimensions with the
98+
// corresponding size in `inType`, otherwise we only
99+
OpShardingRuleBuilder& addPointwiseIfDimSizesMatch(
100+
ArrayRef<int64_t> inShape, ArrayRef<int64_t> outShape,
101+
bool alwaysAddFactor = false,
102+
std::function<void(int64_t dim, OpShardingRuleBuilder& builder)>
103+
onMismatchFn = [](int64_t dim, OpShardingRuleBuilder& builder) {});
104+
97105
private:
98106
MLIRContext* context;
99107
SmallVector<int64_t> factorSizes;

‎shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc

+65-55
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,6 @@ bool isTranspose(stablehlo::Transpose transpose) {
5555
llvm_unreachable("unknown stablehlo::Transpose");
5656
}
5757

58-
// If `addFactorForMismatchedSize` is true, we add a factor for all dimensions
59-
// with the corresponding size in `inType`, otherwise we only add a factor for
60-
// dimensions with the same input and output size, letting the builder add size
61-
// 1 factors for other dimensions.
62-
OpShardingRuleAttr createMismatchedDimSizeShardingRule(
63-
Operation* op, RankedTensorType inType, RankedTensorType outType,
64-
bool addFactorForMismatchedSize) {
65-
return OpShardingRuleBuilder(op)
66-
.addPointwiseIf(inType.getShape(),
67-
[&](int64_t dim) {
68-
return addFactorForMismatchedSize ||
69-
inType.getDimSize(dim) ==
70-
outType.getDimSize(dim);
71-
})
72-
.build();
73-
}
74-
7558
// Returns a vector with `numInputs` copies of `inputDim`, followed by a single
7659
// `indicesDim`, then `numInputs` copies of `updateDim`, which matches the order
7760
// and quantity of scatter operands.
@@ -450,11 +433,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
450433
//
451434
// Operands: [operand, iota, init_val (scalar), init_arg (scalar)]
452435
// Results: [values, indices]
453-
return createMismatchedDimSizeShardingRule(
454-
customCall,
455-
cast<RankedTensorType>(customCall.getOperand(0).getType()),
456-
cast<RankedTensorType>(customCall.getResult(0).getType()),
457-
/*addFactorForMismatchedSize=*/false);
436+
ArrayRef<int64_t> inputShape =
437+
getTensorShape(customCall.getOperand(0));
438+
ArrayRef<int64_t> resultShape =
439+
getTensorShape(customCall.getResult(0));
440+
int64_t numInputs = 2, numResults = 2;
441+
SmallVector<int64_t> operandDims(customCall->getNumOperands(),
442+
kNullDim);
443+
SmallVector<int64_t> resultDims(customCall->getNumResults(),
444+
kNullDim);
445+
return OpShardingRuleBuilder(customCall)
446+
.addPointwiseIfDimSizesMatch(
447+
inputShape, resultShape,
448+
/*alwaysAddFactor=*/false,
449+
/*onMismatchFn=*/
450+
[&](int64_t dim, OpShardingRuleBuilder& builder) {
451+
std::fill_n(operandDims.begin(), numInputs, dim);
452+
resultDims.assign(numResults, kNullDim);
453+
builder.addFactor(operandDims, resultDims, inputShape[dim]);
454+
resultDims.assign(numResults, dim);
455+
std::fill_n(operandDims.begin(), numInputs, kNullDim);
456+
builder.addFactor(operandDims, resultDims,
457+
resultShape[dim]);
458+
})
459+
.build();
458460
}
459461
// TODO(b/327191011): output unregistered op stats instead.
460462
unreachableFormatv(
@@ -540,30 +542,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
540542
})
541543
.Case<stablehlo::DynamicSliceOp>(
542544
[](stablehlo::DynamicSliceOp dynamicSlice) {
543-
return createMismatchedDimSizeShardingRule(
544-
dynamicSlice, dynamicSlice.getOperand().getType(),
545-
dynamicSlice.getType(),
546-
/*addFactorForMismatchedSize=*/false);
545+
return OpShardingRuleBuilder(dynamicSlice)
546+
.addPointwiseIfDimSizesMatch(
547+
getTensorShape(dynamicSlice.getOperand()),
548+
getTensorShape(dynamicSlice.getResult()))
549+
.build();
547550
})
548551
.Case<stablehlo::DynamicUpdateSliceOp>(
549552
[](stablehlo::DynamicUpdateSliceOp dynamicUpdateSlice) {
550-
OpShardingRuleBuilder builder(dynamicUpdateSlice);
551553
ArrayRef<int64_t> operandShape =
552-
dynamicUpdateSlice.getOperand().getType().getShape();
554+
getTensorShape(dynamicUpdateSlice.getOperand());
553555
ArrayRef<int64_t> updateShape =
554-
dynamicUpdateSlice.getUpdate().getType().getShape();
555-
556+
getTensorShape(dynamicUpdateSlice.getUpdate());
556557
SmallVector<int64_t> operandDims(
557558
dynamicUpdateSlice->getNumOperands(), kNullDim);
558-
for (auto [dim, dimSizes] :
559-
llvm::enumerate(llvm::zip_equal(operandShape, updateShape))) {
560-
auto [operandDimSize, updateDimSize] = dimSizes;
561-
operandDims[0] = dim;
562-
operandDims[1] = operandDimSize == updateDimSize ? dim : kNullDim;
563-
builder.addFactor(operandDims, dim, operandDimSize);
564-
}
565-
566-
return builder.build();
559+
return OpShardingRuleBuilder(dynamicUpdateSlice)
560+
.addPointwiseIfDimSizesMatch(
561+
operandShape, updateShape,
562+
/*alwaysAddFactor=*/false,
563+
/*onMismatchFn=*/
564+
[&](int64_t dim, OpShardingRuleBuilder& builder) {
565+
operandDims[0] = dim;
566+
builder.addFactor(operandDims, dim, operandShape[dim]);
567+
})
568+
.build();
567569
})
568570
.Case<stablehlo::FftOp>([](stablehlo::FftOp fft) {
569571
ArrayRef<int64_t> inShape = getTensorShape(fft.getOperand());
@@ -602,9 +604,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
602604
.Case<stablehlo::PadOp>([conservativePropagation](stablehlo::PadOp pad) {
603605
// If `conservativePropagation` is false, we propagate through padded
604606
// dimensions, even though that would require communication.
605-
return createMismatchedDimSizeShardingRule(
606-
pad, pad.getOperand().getType(), pad.getType(),
607-
/*addFactorForMismatchedSize=*/!conservativePropagation);
607+
return OpShardingRuleBuilder(pad)
608+
.addPointwiseIfDimSizesMatch(
609+
getTensorShape(pad.getOperand()),
610+
getTensorShape(pad.getResult()),
611+
/*alwaysAddFactor=*/!conservativePropagation)
612+
.build();
608613
})
609614
.Case<stablehlo::ReduceOp>([](stablehlo::ReduceOp reduce) {
610615
OpShardingRuleBuilder builder(reduce);
@@ -657,12 +662,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
657662
// In conservative mode, we only add a factor if the input and
658663
// output dimension sizes are equal.
659664
// TODO(tomnatan): should the reduced factor be compound?
660-
return createMismatchedDimSizeShardingRule(
661-
reduceWindow,
662-
cast<RankedTensorType>(reduceWindow.getResult(0).getType()),
663-
cast<RankedTensorType>(
664-
reduceWindow.getInputs().front().getType()),
665-
/*addFactorForMismatchedSize=*/!conservativePropagation);
665+
return OpShardingRuleBuilder(reduceWindow)
666+
.addPointwiseIfDimSizesMatch(
667+
getTensorShape(reduceWindow.getInputs().front()),
668+
getTensorShape(reduceWindow.getResult(0)),
669+
/*alwaysAddFactor=*/!conservativePropagation)
670+
.build();
666671
})
667672
.Case<stablehlo::ReshapeOp>([](stablehlo::ReshapeOp reshape) {
668673
RankedTensorType inType = reshape.getOperand().getType();
@@ -790,10 +795,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
790795
// In conservative mode, we only add a factor if the input and
791796
// source dimension sizes are equal.
792797
// TODO(tomnatan): should the reduced factor be compound?
793-
return createMismatchedDimSizeShardingRule(
794-
selectAndScatter, selectAndScatter.getSource().getType(),
795-
selectAndScatter.getOperand().getType(),
796-
/*addFactorForMismatchedSize=*/!conservativePropagation);
798+
return OpShardingRuleBuilder(selectAndScatter)
799+
.addPointwiseIfDimSizesMatch(
800+
getTensorShape(selectAndScatter.getOperand()),
801+
getTensorShape(selectAndScatter.getSource()),
802+
/*alwaysAddFactor=*/!conservativePropagation)
803+
.build();
797804
})
798805
.Case<stablehlo::SelectOp>([](stablehlo::SelectOp select) {
799806
// Case 1: `pred` is a scalar in which case it is broadcasted and must
@@ -815,9 +822,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
815822
// `conservativePropagation`, and the reason is that for `SliceOp`
816823
// the start indices are static, so we know how to shift the data
817824
// to keep the sliced dimension sharded.
818-
return createMismatchedDimSizeShardingRule(
819-
slice, slice.getOperand().getType(), slice.getType(),
820-
/*addFactorForMismatchedSize=*/!conservativePropagation);
825+
return OpShardingRuleBuilder(slice)
826+
.addPointwiseIfDimSizesMatch(
827+
getTensorShape(slice.getOperand()),
828+
getTensorShape(slice.getResult()),
829+
/*alwaysAddFactor=*/!conservativePropagation)
830+
.build();
821831
})
822832
.Case<stablehlo::TransposeOp>([](stablehlo::TransposeOp transpose) {
823833
OpShardingRuleBuilder builder(transpose);

‎shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir

+15-15
Original file line numberDiff line numberDiff line change
@@ -194,43 +194,43 @@ func.func @custom_call_householder_product(%arg0: tensor<8x12x16xf32>, %arg1: te
194194
}
195195

196196
// CHECK-LABEL: func @custom_call_approx_topk
197-
func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
198-
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
197+
func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
198+
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
199199
%0:2 = stablehlo.custom_call @ApproxTopK(%arg0, %arg1, %arg2, %arg3) {
200200
mhlo.backend_config = {
201201
aggregate_to_topk = true,
202202
recall_target = 0.9 : f32,
203203
reduction_dim = 1 : i64,
204204
reduction_input_size_override = -1 : i64,
205-
top_k = 1 : i64},
205+
top_k = 2 : i64},
206206
called_computations = [@top_k_gt_f32_comparator]} :
207-
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
208-
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
207+
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
208+
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
209209
}
210210

211211
// CHECK-LABEL: func @custom_call_partial_reduce
212-
func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
213-
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
212+
func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
213+
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
214214
%0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) {
215215
mhlo.backend_config = {
216216
aggregate_to_topk = true,
217217
recall_target = 0.9 : f32,
218218
reduction_dim = 1 : i64,
219219
reduction_input_size_override = -1 : i64,
220-
top_k = 1 : i64},
220+
top_k = 2 : i64},
221221
called_computations = [@top_k_gt_f32_comparator]} :
222-
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
223-
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
222+
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
223+
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
224224
}
225225

226226
// CHECK-LABEL: func @custom_call_partial_reduce_string_backend_config
227-
func.func @custom_call_partial_reduce_string_backend_config(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
228-
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
227+
func.func @custom_call_partial_reduce_string_backend_config(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
228+
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
229229
%0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) {
230-
backend_config = "{\22log2_reduction\22: 5, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 64, \22recall_target\22: 0.950000}",
230+
backend_config = "{\22log2_reduction\22: 5, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 2, \22recall_target\22: 0.950000}",
231231
called_computations = [@top_k_gt_f32_comparator]} :
232-
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
233-
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
232+
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
233+
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
234234
}
235235

236236
// CHECK-LABEL: func @unregisterd_custom_call_with_existing_rule

0 commit comments

Comments
 (0)