diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td
index cbb529b4..68fbd954 100644
--- a/shardy/dialect/sdy/ir/attrs.td
+++ b/shardy/dialect/sdy/ir/attrs.td
@@ -731,6 +731,11 @@ def Sdy_TensorSharding : AttrDef {
         std::function emitError,
         bool checkDivisibility = true);
 
+    // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
+    // being marked closed (cannot be further replicated/sharded).
+    static TensorShardingAttr getFullyClosed(
+        MLIRContext* context, int64_t rank, Attribute meshOrRef);
+
     // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
     // being marked closed (cannot be further replicated/sharded).
     static TensorShardingAttr getFullyClosed(
@@ -750,6 +755,11 @@
         MLIRContext* context, Attribute meshOrRef,
         ArrayRef> axesPerDim);
 
+    // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
+    // being marked open (can be further replicated/sharded).
+    static TensorShardingAttr getFullyOpen(
+        MLIRContext* context, int64_t rank, Attribute meshOrRef);
+
     // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
     // being marked open (can be further replicated/sharded).
     static TensorShardingAttr getFullyOpen(
diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc
index ad886e3a..1aa7c15d 100644
--- a/shardy/dialect/sdy/ir/dialect.cc
+++ b/shardy/dialect/sdy/ir/dialect.cc
@@ -786,6 +786,12 @@ TensorShardingAttr TensorShardingAttr::getReplicated(StringRef axisName,
                      getDimShardings(), newReplicatedAxes);
 }
 
+TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
+                                                      int64_t rank,
+                                                      Attribute meshOrRef) {
+  return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/true);
+}
+
 TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
                                                       int64_t rank,
                                                       StringRef meshName) {
@@ -823,6 +829,13 @@ TensorShardingAttr TensorShardingAttr::getClosed(
                                /*replicatedAxes=*/{});
 }
 
+TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
+                                                    int64_t rank,
+                                                    Attribute meshOrRef) {
+  return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/false);
+}
+
+
 TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
                                                     int64_t rank,
                                                     StringRef meshName) {
@@ -1393,6 +1406,14 @@ LogicalResult NamedComputationOp::inferReturnTypes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllSliceOp
+//===----------------------------------------------------------------------===//
+
+bool AllSliceOp::allowMissingInputSharding() { return true; }
+
+Type AllSliceOp::getType() { return getResult().getType(); }
+
 //===----------------------------------------------------------------------===//
 // CollectivePermuteOp
 //===----------------------------------------------------------------------===//
@@ -1401,6 +1422,14 @@
 bool CollectivePermuteOp::allowDifferentMeshes() { return true; }
 
 Type CollectivePermuteOp::getType() { return getResult().getType(); }
 
+//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+bool AllReduceOp::allowMissingInputSharding() { return true; }
+
+Type AllReduceOp::getType() { return getResult().getType(); }
+
 }  // namespace sdy
 }  // namespace mlir
diff --git a/shardy/dialect/sdy/ir/op_interface.td b/shardy/dialect/sdy/ir/op_interface.td
index f52e7a2b..e069f74f 100644
--- a/shardy/dialect/sdy/ir/op_interface.td
+++ b/shardy/dialect/sdy/ir/op_interface.td
@@ -248,7 +248,8 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
     outSharding attribute.
 
     **Constraints:**
-    - Operand must have a sharding.
+    - Operand must have a sharding or `allowMissingInputSharding()` returns
+      true.
     - `out_sharding` is valid w.r.t the corresponding type.
     - Operand and result sharding must have the same mesh if
       `allowDifferentMeshes()` returns false.
@@ -291,6 +292,17 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/"return false;"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Indicates whether the collective op allows the input to have no
+        sharding, i.e., to be implicitly fully replicated.
+      }],
+      /*retType=*/"bool",
+      /*methodName=*/"allowMissingInputSharding",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return false;"
     >
   ];
   let verify = [{
diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td
index 15fdd4dd..6f8d0ac7 100644
--- a/shardy/dialect/sdy/ir/ops.td
+++ b/shardy/dialect/sdy/ir/ops.td
@@ -493,7 +493,9 @@ def Sdy_AllGatherOp : Sdy_Op<"all_gather",
 }
 
 def Sdy_AllSliceOp : Sdy_Op<"all_slice",
-  [SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
+  [SameOperandsAndResultType, InferTypeOpInterface,
+   DeclareOpInterfaceMethods]> {
   let summary = "Performs a dynamic-slice operation along axes";
   let description = [{
     Slices chunks of a tensor along axes specified in `slicing_axes`. There is
@@ -633,7 +635,9 @@ def Sdy_CollectivePermuteOp : Sdy_Op<"collective_permute",
 }
 
 def Sdy_AllReduceOp: Sdy_Op<"all_reduce",
-  [SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
+  [SameOperandsAndResultType, InferTypeOpInterface,
+   DeclareOpInterfaceMethods]> {
   let summary = "Perform an all-reduce comunication along axes";
   let description = [{
     Reduces chunks of a tensor along axes specified in `reduction_axes`.
diff --git a/shardy/dialect/sdy/ir/test/collective_parse_print.mlir b/shardy/dialect/sdy/ir/test/collective_parse_print.mlir
index c9847ff6..fc14cdcc 100644
--- a/shardy/dialect/sdy/ir/test/collective_parse_print.mlir
+++ b/shardy/dialect/sdy/ir/test/collective_parse_print.mlir
@@ -93,6 +93,13 @@ func.func @all_slice4(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh
   return %0 : tensor<16x8xf32>
 }
 
+// CHECK-LABEL: func @all_slice_missing_in_sharding
+func.func @all_slice_missing_in_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]>
+  %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @all_slice_subaxis_exact_match
 func.func @all_slice_subaxis_exact_match(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3, [{"y"}, {}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_slice [{}, {"x":(1)2}] %arg0 out_sharding=<@mesh3, [{"y"}, {"x":(1)2}]>
@@ -233,6 +240,13 @@ func.func @all_reduce(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh
   return %0 : tensor<16x2xf32>
 }
 
+// CHECK-LABEL: func @all_reduce_missing_in_sharding
+func.func @all_reduce_missing_in_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
+  // CHECK-NEXT: sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
+  %0 = sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
+  return %0 : tensor<16x2xf32>
+}
+
 sdy.mesh @mesh_xyzw = <["x"=2, "y"=2, "z"=2, "w"=2]>
 
 // CHECK-LABEL: func @all_reduce_many_axes
diff --git a/shardy/dialect/sdy/ir/test/collective_verification.mlir b/shardy/dialect/sdy/ir/test/collective_verification.mlir
index 7194c7a4..ed1fbc5d 100644
--- a/shardy/dialect/sdy/ir/test/collective_verification.mlir
+++ b/shardy/dialect/sdy/ir/test/collective_verification.mlir
@@ -132,16 +132,6 @@ func.func @all_gather_with_incompatible_result_sharding_subaxis(%arg0 : tensor<1
 
 // -----
 
-sdy.mesh @mesh = <["x"=2, "y"=2]>
-
-func.func @all_slice_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
-  // expected-error @+1 {{collective on operand without sharding}}
-  %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x8xf32>
-  return %0 : tensor<16x8xf32>
-}
-
-// -----
-
 sdy.mesh @mesh1 = <["x"=2, "y"=2]>
 sdy.mesh @mesh2 = <["a"=2, "b"=2]>
 
@@ -398,6 +388,16 @@ func.func @all_to_all_incompatible_result_sharding_non_moved_dim(%arg0 : tensor<
 
 sdy.mesh @mesh = <["x"=2, "y"=2]>
 
+func.func @collective_permute_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // expected-error @+1 {{collective on operand without sharding}}
+  %0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{}, {"x"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["x"=2, "y"=2]>
+
 func.func @collective_permute_invalid_out_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // expected-error @+1 {{duplicate axis ref: "x"}}
   %0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{"y", "x"}, {"x"}]> : tensor<16x8xf32>
@@ -519,15 +519,6 @@ sdy.mesh @mesh= <["x"=2, "y"=8, "z"=2]>
 
 // -----
 
-sdy.mesh @mesh = <["x"=2, "y"=2]>
-
-func.func @all_reduce_on_operand_without_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
-  // expected-error @+1 {{'sdy.all_reduce' op collective on operand without sharding}}
-  %0 = sdy.all_reduce {"x"} %arg0 out_sharding=<@mesh, [{"y"}, {}]> : tensor<16x2xf32>
-  return %0 : tensor<16x2xf32>
-}
-// -----
-
 sdy.mesh @mesh = <["x"=4, "y"=2]>
 
 func.func @all_reduce_reduction_axes_can_be_merged(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x2xf32> {
diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc
index d91ba27f..e66a586c 100644
--- a/shardy/dialect/sdy/ir/utils.cc
+++ b/shardy/dialect/sdy/ir/utils.cc
@@ -316,16 +316,23 @@ TensorShardingAttr getSharding(Value value) {
   });
 }
 
-TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
+TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
                                        const bool closedIfMissing) {
   if (TensorShardingAttr sharding = getSharding(value)) {
     return sharding;
   }
 
   return closedIfMissing ? TensorShardingAttr::getFullyClosed(
-                               value.getContext(), getTensorRank(value), meshName)
-                         : TensorShardingAttr::getFullyOpen(value.getContext(),
-                                                getTensorRank(value), meshName);
+                               value.getContext(), getTensorRank(value), meshOrRef)
+                         : TensorShardingAttr::getFullyOpen(
+                               value.getContext(), getTensorRank(value), meshOrRef);
+}
+
+TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
+                                       const bool closedIfMissing) {
+  return getOrCreateSharding(
+      value, FlatSymbolRefAttr::get(value.getContext(), meshName),
+      closedIfMissing);
 }
 
 void setSharding(Value value, TensorShardingAttr sharding) {
diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h
index ef589282..ebec983a 100644
--- a/shardy/dialect/sdy/ir/utils.h
+++ b/shardy/dialect/sdy/ir/utils.h
@@ -253,6 +253,12 @@ TensorShardingAttr getSharding(Value value);
 TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
                                        bool closedIfMissing = false);
 
+// Returns the sharding of the given `value`, or a fully open (closed) empty
+// `TensorShardingAttr` if `value` doesn't have a sharding and `closedIfMissing`
+// is false (true).
+TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
+                                       bool closedIfMissing = false);
+
 // Sets the `TensorShardingPerValueAttr` of the given `op`, but
 // replaces the sharding at the given `index` with the given `sharding`.
 //
diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc
index a7c1605e..90cc7c88 100644
--- a/shardy/dialect/sdy/ir/verifiers.cc
+++ b/shardy/dialect/sdy/ir/verifiers.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -1039,8 +1038,9 @@ LogicalResult verifyCollectiveWithAxesPerDim(
         DimensionShardingAttr operandDimSharding,
         ArrayRef dimCollectiveAxes, int64_t dim, MeshAttr mesh)>
         getExpectedResultDimSharding) {
-  TensorShardingAttr operandSharding = getSharding(op.getOperand());
   TensorShardingAttr resultSharding = op.getOutSharding();
+  TensorShardingAttr operandSharding =
+      getOrCreateSharding(op.getOperand(), resultSharding.getMeshOrRef());
   MeshAttr mesh = resultSharding.getMesh(op);
   MeshAttr operandMesh = operandSharding.getMesh(op);
 
@@ -1285,8 +1285,9 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
     return failure();
  }
   // 1. Verify operand has a sharding.
-  TensorShardingAttr operandSharding = getSharding(collectiveOp.getTensor());
-  if (!operandSharding) {
+  TensorShardingAttr optionalOperandSharding =
+      getSharding(collectiveOp.getTensor());
+  if (!collectiveOp.allowMissingInputSharding() && !optionalOperandSharding) {
     return collectiveOp.emitOpError("collective on operand without sharding");
   }
 
@@ -1302,8 +1303,10 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
   // 3. Verify MeshAttr of result and operand is the same.
   if (!collectiveOp.allowDifferentMeshes()) {
     MeshAttr mesh = resultSharding.getMesh(collectiveOp);
-    MeshAttr operandMesh = operandSharding.getMesh(collectiveOp);
-    if (mesh != operandMesh) {
+    MeshAttr operandMesh = optionalOperandSharding
+                               ? optionalOperandSharding.getMesh(collectiveOp)
+                               : nullptr;
+    if (operandMesh && mesh != operandMesh) {
       return collectiveOp.emitOpError("result mesh does not match operand mesh")
           .attachNote(collectiveOp.getTensor().getLoc())
           << "operand mesh: " << operandMesh;
@@ -1311,12 +1314,11 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
   }
 
   // 4. Verify same rank of the result sharding and operand sharding.
-  auto resultDimShardings = resultSharding.getRank();
-  auto operandDimShardings = operandSharding.getRank();
-  if (resultDimShardings != operandDimShardings) {
+  if (optionalOperandSharding &&
+      resultSharding.getRank() != optionalOperandSharding.getRank()) {
     return collectiveOp.emitOpError("result sharding has rank ")
-           << resultDimShardings << " but operand sharding has rank "
-           << operandDimShardings;
+           << resultSharding.getRank() << " but operand sharding has rank "
+           << optionalOperandSharding.getRank();
   }
   return success();
 }
@@ -1373,8 +1375,9 @@ LogicalResult SdyDialect::verifyOperationAttribute(Operation* op,
 }
 
 LogicalResult AllReduceOp::verify() {
-  TensorShardingAttr operandSharding = getSharding(getOperand());
   TensorShardingAttr resultSharding = getOutSharding();
+  TensorShardingAttr operandSharding =
+      getOrCreateSharding(getOperand(), resultSharding.getMeshOrRef());
   MeshAttr mesh = resultSharding.getMesh(*this);
   if (!operandSharding.areDimAxesEqual(resultSharding)) {
     return emitOpError("operand and result sharding have different axes");
diff --git a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
index 691099b8..d9b301bc 100644
--- a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
+++ b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
@@ -1235,8 +1235,9 @@ class ReshardPattern : public OpConversionPattern {
   LogicalResult matchAndRewrite(
       ReshardOp op, OpAdaptor adaptor,
       ConversionPatternRewriter& rewriter) const override {
-    TensorShardingAttr inSharding = getSharding(adaptor.getInput());
     TensorShardingAttr outSharding = adaptor.getSharding();
+    TensorShardingAttr inSharding =
+        getOrCreateSharding(adaptor.getInput(), outSharding.getMeshName());
     // Here it's safe to assume that shardings' meshes have a name.
     if (inSharding.getRank() != outSharding.getRank()) {
       return rewriter.notifyMatchFailure(
diff --git a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
index 15a9eb82..3591896d 100644
--- a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
@@ -57,6 +57,13 @@ func.func @all_slice_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.s
   return %0 : tensor<16x8xf32>
 }
 
+// CHECK-LABEL: func @all_slice_with_subaxis
+func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
+  %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @all_slice_minor_axis
 func.func @all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_slice [{}, {"z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
@@ -64,10 +71,10 @@ func.func @all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.shar
   return %0 : tensor<16x8xf32>
 }
 
-// CHECK-LABEL: func @all_slice_with_subaxis
-func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x":(1)2}, {"y"}]>}) -> tensor<16x8xf32> {
-  // CHECK-NEXT: sdy.all_slice [{"x":(2)2}, {"z":(1)2}] %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]>
-  %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]> : tensor<16x8xf32>
+// CHECK-LABEL: func @all_slice_missing_input_sharding
+func.func @all_slice_missing_input_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{}, {}]>}) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
+  %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
   return %0 : tensor<16x8xf32>
 }
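
Example (illustrative only, not part of the patch): with `allowMissingInputSharding()` returning true for `sdy.all_slice` and `sdy.all_reduce`, a collective whose operand carries no `sdy.sharding` attribute now passes verification; the operand is treated as implicitly fully replicated and only `out_sharding` is checked against the type. A minimal sketch, assuming a mesh `@mesh1 = <["x"=2, "y"=2]>` as in the test files above (the function name here is hypothetical):

  sdy.mesh @mesh1 = <["x"=2, "y"=2]>

  func.func @slice_unsharded_operand(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
    // %arg0 has no sdy.sharding attribute; before this change the verifier
    // rejected this op with "collective on operand without sharding".
    %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]> : tensor<16x8xf32>
    return %0 : tensor<16x8xf32>
  }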