From 954f170e68eeb4e852f7c6e90f39a5eb273f802b Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 12 Dec 2024 21:45:51 -0800 Subject: [PATCH] #sdy. Fix sharding rule for SliceOp. Previously, we added factors for all dimensions even though the dimension size mismatches. This cl adds an exception. If the input dimension size is larger than 1 and the output dimension size is 1, we do not propagate through this dimension. PiperOrigin-RevId: 705743468 --- .../propagation/op_sharding_rule_registry.cc | 47 ++++++++++++------- .../propagation/sharding_projection_test.cc | 6 +-- .../test/op_sharding_rule_registry.mlir | 8 ++-- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc index d7fa040c..64ea8708 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc +++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc @@ -863,23 +863,36 @@ OpShardingRuleAttr createOpShardingRule(Operation* op, .addPointwise(getTensorShape(select.getResult())) .build(); }) - .Case( - [conservativePropagation](stablehlo::SliceOp slice) { - // If `conservativePropagation` is false, we propagate through - // sliced dimensions, even though that would require communication. - // - // This is different from `DynamicSliceOp`, where we don't - // propagate through sliced dimensions regardless of - // `conservativePropagation`, and the reason is that for `SliceOp` - // the start indices are static, so we know how to shift the data - // to keep the sliced dimension sharded. - return OpShardingRuleBuilder(slice) - .addPointwiseIfDimSizesMatch( - getTensorShape(slice.getOperand()), - getTensorShape(slice.getResult()), - /*alwaysAddFactor=*/!conservativePropagation) - .build(); - }) + .Case([conservativePropagation]( + stablehlo::SliceOp slice) { + // If `conservativePropagation` is false, we propagate through + // sliced dimensions, even though that would require communication. + // + // There is an exception. If the input dimension size is larger than 1 + // and the output dimension size is 1, we do not propagate through this + // sliced dimension. + // + // This is different from `DynamicSliceOp`, where we don't + // propagate through sliced dimensions regardless of + // `conservativePropagation`, and the reason is that for `SliceOp` + // the start indices are static, so we know how to shift the data + // to keep the sliced dimension sharded. + ArrayRef inShape = getTensorShape(slice.getOperand()); + ArrayRef outShape = getTensorShape(slice.getResult()); + auto onMismatchFn = [&](int64_t dim, OpShardingRuleBuilder& builder) { + if (conservativePropagation) { + return; + } + if (inShape[dim] != 1 && outShape[dim] == 1) { + return; + } + builder.addFactor(dim, inShape[dim]); + }; + return OpShardingRuleBuilder(slice) + .addPointwiseIfDimSizesMatch( + inShape, outShape, /*alwaysAddFactor=*/false, onMismatchFn) + .build(); + }) .Case([](stablehlo::SortOp sort) { // If the input is sharded along the sort dimension, and any of the // non-sort dimensions has size >1, the partitioner will add an diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc index 20c914c0..ddace2ad 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc @@ -360,9 +360,9 @@ TEST_F(ShardingProjectionBuildTest, FactorWithSmallerSizeThanDimOverflows) { sdy.mesh @mesh = <["a"=2, "b"=4, "c"=2, "d"=4, "e"=2]> func.func @main(%arg0: tensor<32x4x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}, {"b", "d":(2)2, "e"}]>}) - -> tensor<32x1x2xf32> { - %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:6] : (tensor<32x4x16xf32>) -> tensor<32x1x2xf32> - return %0 : tensor<32x1x2xf32> + -> tensor<32x2x2xf32> { + %0 = stablehlo.slice %arg0 [0:32, 0:2, 4:6] : (tensor<32x4x16xf32>) -> tensor<32x2x2xf32> + return %0 : tensor<32x2x2xf32> })mlir"; OwningOpRef module = diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir index 87049ba7..1a8bbfb0 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir @@ -452,10 +452,10 @@ func.func @pad(%arg0: tensor<28x28x16xf32>, %arg1: tensor) -> tensor<30x26x } // CHECK-LABEL: func @slice -func.func @slice(%arg0: tensor<32x4x8xf32>) -> tensor<32x1x2xf32> { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, j, k]) {i=32, j=4, k=8}> - %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:8:2] : (tensor<32x4x8xf32>) -> tensor<32x1x2xf32> - return %0 : tensor<32x1x2xf32> +func.func @slice(%arg0: tensor<32x4x8x1xf32>) -> tensor<32x1x2x1xf32> { + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, l, j, k])->([i, m, j, k]) {i=32, j=8, k=1, l=1, m=1}> + %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:8:2, 0:1] : (tensor<32x4x8x1xf32>) -> tensor<32x1x2x1xf32> + return %0 : tensor<32x1x2x1xf32> } // Sort is currently treated as a pointwise op, and we add a factor for the sort