diff --git a/shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc b/shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc
index 86cd2ed8..7ee7a1d2 100644
--- a/shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc
+++ b/shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc
@@ -34,8 +34,8 @@ namespace sdy {
 
 namespace {
 
 bool shouldApply(Value input, Operation* op) {
-  if (getSharding(input)) {
-    // `input` already has a sharding.
+  if (getSharding(input) || input.getDefiningOp<DataFlowEdgeOp>()) {
+    // `input` already has a sharding or is produced by a `DataFlowEdgeOp`.
     return false;
   }
diff --git a/shardy/dialect/sdy/transforms/import/passes.td b/shardy/dialect/sdy/transforms/import/passes.td
index 4a1c01c8..779b37c5 100644
--- a/shardy/dialect/sdy/transforms/import/passes.td
+++ b/shardy/dialect/sdy/transforms/import/passes.td
@@ -37,6 +37,8 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func.
     all of the following:
 
     * The input doesn't have an existing sharding.
+    * The input isn't produced by a `DataFlowEdgeOp`, which holds the sharding
+      of all targets of the edge.
     * The input is either only used by the `ShardingConstraintOp` or the latter
       doesn't have any uses (dangling) and the input doesn't have any other
       users of type `ShardingConstraintOp` or `ManualComputationOp`.
diff --git a/shardy/dialect/sdy/transforms/import/test/apply_sharding_constraints.mlir b/shardy/dialect/sdy/transforms/import/test/apply_sharding_constraints.mlir
index b20b9ddb..2f774043 100644
--- a/shardy/dialect/sdy/transforms/import/test/apply_sharding_constraints.mlir
+++ b/shardy/dialect/sdy/transforms/import/test/apply_sharding_constraints.mlir
@@ -18,6 +18,16 @@ func.func @input_has_one_use(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   return %1 : tensor<8x8xf32>
 }
 
+// CHECK-LABEL: func @input_produced_by_data_flow_edge
+func.func @input_produced_by_data_flow_edge(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  // CHECK-NEXT: sdy.data_flow_edge %arg0 : tensor<8x8xf32>
+  // CHECK-NOT: sdy.sharding
+  // CHECK-NEXT: sdy.sharding_constraint
+  %0 = sdy.data_flow_edge %arg0 : tensor<8x8xf32>
+  %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
+  return %1 : tensor<8x8xf32>
+}
+
 // CHECK-LABEL: func @input_is_func_input_with_one_use(
 // CHECK-SAMEL %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>})
 func.func @input_is_func_input_with_one_use(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
diff --git a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir
index a2cce334..86fca6bd 100644
--- a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir
+++ b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir
@@ -1,5 +1,6 @@
-// RUN: sdy_opt %s -sdy-import-pipeline 2>&1 | FileCheck %s
+// RUN: sdy_opt %s -split-input-file -sdy-import-pipeline 2>&1 | FileCheck %s
 
+// Verifies that the `-inliner` pass is applied.
 // CHECK-LABEL: func @main
 func.func @main(%arg0: tensor<16x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
   // CHECK-NEXT: %[[CONST_0:.*]] = sdy.constant dense<1.000000e+00>
@@ -19,3 +20,20 @@ func.func private @add_matmul_to_lhs(%arg0: tensor<8x16xf32>, %arg1: tensor<16x1
   %1 = stablehlo.add %arg0, %0 : tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }
+
+// -----
+
+sdy.mesh @mesh = <"a"=2>
+
+// Verifies that the `-apply-sharding-constraints` pass is applied after the
+// `-add-data-flow-edges` pass.
+// CHECK-LABEL: func @main
+func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+  // CHECK-NEXT: %[[OPT_BARRIER:.*]] = stablehlo.optimization_barrier %arg0
+  // CHECK-NEXT: sdy.data_flow_edge %[[OPT_BARRIER]] : tensor<32x96xf32>
+  // CHECK-NOT: sdy.sharding
+  // CHECK-NEXT: sdy.sharding_constraint
+  %0 = stablehlo.optimization_barrier %arg0 : tensor<32x96xf32>
+  %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"a"}]> : tensor<32x96xf32>
+  return %1 : tensor<32x96xf32>
+}