diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index d7fa040c..64ea8708 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -863,23 +863,36 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
             .addPointwise(getTensorShape(select.getResult()))
             .build();
       })
-      .Case<stablehlo::SliceOp>(
-          [conservativePropagation](stablehlo::SliceOp slice) {
-            // If `conservativePropagation` is false, we propagate through
-            // sliced dimensions, even though that would require communication.
-            //
-            // This is different from `DynamicSliceOp`, where we don't
-            // propagate through sliced dimensions regardless of
-            // `conservativePropagation`, and the reason is that for `SliceOp`
-            // the start indices are static, so we know how to shift the data
-            // to keep the sliced dimension sharded.
-            return OpShardingRuleBuilder(slice)
-                .addPointwiseIfDimSizesMatch(
-                    getTensorShape(slice.getOperand()),
-                    getTensorShape(slice.getResult()),
-                    /*alwaysAddFactor=*/!conservativePropagation)
-                .build();
-          })
+      .Case<stablehlo::SliceOp>([conservativePropagation](
+                                    stablehlo::SliceOp slice) {
+        // If `conservativePropagation` is false, we propagate through
+        // sliced dimensions, even though that would require communication.
+        //
+        // There is an exception: if the input dimension size is larger than 1
+        // and the output dimension size is 1, we do not propagate through this
+        // sliced dimension.
+        //
+        // This is different from `DynamicSliceOp`, where we don't
+        // propagate through sliced dimensions regardless of
+        // `conservativePropagation`, and the reason is that for `SliceOp`
+        // the start indices are static, so we know how to shift the data
+        // to keep the sliced dimension sharded.
+        ArrayRef<int64_t> inShape = getTensorShape(slice.getOperand());
+        ArrayRef<int64_t> outShape = getTensorShape(slice.getResult());
+        auto onMismatchFn = [&](int64_t dim, OpShardingRuleBuilder& builder) {
+          if (conservativePropagation) {
+            return;
+          }
+          if (inShape[dim] != 1 && outShape[dim] == 1) {
+            return;
+          }
+          builder.addFactor(dim, inShape[dim]);
+        };
+        return OpShardingRuleBuilder(slice)
+            .addPointwiseIfDimSizesMatch(
+                inShape, outShape, /*alwaysAddFactor=*/false, onMismatchFn)
+            .build();
+      })
       .Case<stablehlo::SortOp>([](stablehlo::SortOp sort) {
         // If the input is sharded along the sort dimension, and any of the
         // non-sort dimensions has size >1, the partitioner will add an
diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
index 20c914c0..ddace2ad 100644
--- a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
+++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
@@ -360,9 +360,9 @@ TEST_F(ShardingProjectionBuildTest, FactorWithSmallerSizeThanDimOverflows) {
     sdy.mesh @mesh = <["a"=2, "b"=4, "c"=2, "d"=4, "e"=2]>
 
     func.func @main(%arg0: tensor<32x4x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"c", ?}, {"b", "d":(2)2, "e"}]>})
-        -> tensor<32x1x2xf32> {
-      %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:6] : (tensor<32x4x16xf32>) -> tensor<32x1x2xf32>
-      return %0 : tensor<32x1x2xf32>
+        -> tensor<32x2x2xf32> {
+      %0 = stablehlo.slice %arg0 [0:32, 0:2, 4:6] : (tensor<32x4x16xf32>) -> tensor<32x2x2xf32>
+      return %0 : tensor<32x2x2xf32>
     })mlir";
 
   OwningOpRef<ModuleOp> module =
diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
index 87049ba7..1a8bbfb0 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
@@ -452,10 +452,10 @@ func.func @pad(%arg0: tensor<28x28x16xf32>, %arg1: tensor<f32>) -> tensor<30x26x16xf32> {
 }
 
 // CHECK-LABEL: func @slice
-func.func @slice(%arg0: tensor<32x4x8xf32>) -> tensor<32x1x2xf32> {
-  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, j, k]) {i=32, j=4, k=8}>
-  %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:8:2] : (tensor<32x4x8xf32>) -> tensor<32x1x2xf32>
-  return %0 : tensor<32x1x2xf32>
+func.func @slice(%arg0: tensor<32x4x8x1xf32>) -> tensor<32x1x2x1xf32> {
+  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, l, j, k])->([i, m, j, k]) {i=32, j=8, k=1, l=1, m=1}>
+  %0 = stablehlo.slice %arg0 [0:32, 1:2, 4:8:2, 0:1] : (tensor<32x4x8x1xf32>) -> tensor<32x1x2x1xf32>
+  return %0 : tensor<32x1x2x1xf32>
}
 
 // Sort is currently treated as a pointwise op, and we add a factor for the sort
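Reviewer note: as a minimal standalone sketch of the per-dimension decision the new `onMismatchFn` encodes, the snippet below uses a hypothetical helper name (`propagatesThroughSlicedDim`, not part of the Shardy API) and checks it against the dimension sizes from the updated @slice test. It only distills the logic of the callback above plus the pointwise case of `addPointwiseIfDimSizesMatch`; it is not the implementation itself.

#include <cassert>
#include <cstdint>

// Hypothetical helper (illustration only, not Shardy API): does the slice
// sharding rule add a propagation factor for a dimension whose operand size
// is `inSize` and whose result size is `outSize`?
bool propagatesThroughSlicedDim(int64_t inSize, int64_t outSize,
                                bool conservativePropagation) {
  // Dimensions that are not actually sliced are always treated pointwise.
  if (inSize == outSize) return true;
  // Conservative propagation never propagates through sliced dimensions.
  if (conservativePropagation) return false;
  // New exception: a dimension that collapses to size 1 is not propagated.
  if (inSize != 1 && outSize == 1) return false;
  // Otherwise a factor of the operand's dimension size is added.
  return true;
}

int main() {
  // Mirrors the updated @slice test (tensor<32x4x8x1xf32> -> tensor<32x1x2x1xf32>).
  assert(propagatesThroughSlicedDim(32, 32, /*conservativePropagation=*/false));
  assert(!propagatesThroughSlicedDim(4, 1, /*conservativePropagation=*/false));  // 4 -> 1: skipped
  assert(propagatesThroughSlicedDim(8, 2, /*conservativePropagation=*/false));   // 8 -> 2: factor of size 8
  assert(!propagatesThroughSlicedDim(8, 2, /*conservativePropagation=*/true));   // conservative: skipped
  return 0;
}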