#sdy Allow all-slice and all-reduce to not have an input sharding, since the input can have a fully replicated sharding. #425

Merged: 1 commit, Mar 18, 2025
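This change lets `sdy.all_slice` and `sdy.all_reduce` accept an operand that carries no `sdy.sharding` attribute, treating the missing input sharding as fully replicated. A minimal sketch of IR that is now accepted, adapted from the parse/print tests added in this PR (the `@mesh1` definition is assumed here so the snippet is self-contained):

```mlir
// Assumed mesh definition; the snippet only relies on an axis named "x".
sdy.mesh @mesh1 = <["x"=2, "y"=2]>

// %arg0 has no sdy.sharding attribute, so it is implicitly fully replicated;
// the op now only needs a valid out_sharding to verify.
func.func @all_slice_missing_in_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
  %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]> : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}
```

Collectives that do not opt in keep the existing requirement that the operand be sharded.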
10 changes: 10 additions & 0 deletions shardy/dialect/sdy/ir/attrs.td
@@ -731,6 +731,11 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
std::function<InFlightDiagnostic(StringRef)> emitError,
bool checkDivisibility = true);

// Builds a `TensorShardingAttr` with all dim shardings and replicated axes
// being marked closed (cannot be further replicated/sharded).
static TensorShardingAttr getFullyClosed(
MLIRContext* context, int64_t rank, Attribute meshOrRef);

// Builds a `TensorShardingAttr` with all dim shardings and replicated axes
// being marked closed (cannot be further replicated/sharded).
static TensorShardingAttr getFullyClosed(
@@ -750,6 +755,11 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
MLIRContext* context, Attribute meshOrRef,
ArrayRef<SmallVector<AxisRefAttr>> axesPerDim);

// Builds a `TensorShardingAttr` with all dim shardings and replicated axes
// being marked open (can be further replicated/sharded).
static TensorShardingAttr getFullyOpen(
MLIRContext* context, int64_t rank, Attribute meshOrRef);

// Builds a `TensorShardingAttr` with all dim shardings and replicated axes
// being marked open (can be further replicated/sharded).
static TensorShardingAttr getFullyOpen(
29 changes: 29 additions & 0 deletions shardy/dialect/sdy/ir/dialect.cc
@@ -786,6 +786,12 @@ TensorShardingAttr TensorShardingAttr::getReplicated(StringRef axisName,
getDimShardings(), newReplicatedAxes);
}

TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
int64_t rank,
Attribute meshOrRef) {
return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/true);
}

TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
int64_t rank,
StringRef meshName) {
@@ -823,6 +829,13 @@ TensorShardingAttr TensorShardingAttr::getClosed(
/*replicatedAxes=*/{});
}

TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
int64_t rank,
Attribute meshOrRef) {
return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/false);
}


TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
int64_t rank,
StringRef meshName) {
@@ -1393,6 +1406,14 @@ LogicalResult NamedComputationOp::inferReturnTypes(
return success();
}

//===----------------------------------------------------------------------===//
// AllSliceOp
//===----------------------------------------------------------------------===//

bool AllSliceOp::allowMissingInputSharding() { return true; }

Type AllSliceOp::getType() { return getResult().getType(); }

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//
@@ -1401,6 +1422,14 @@ bool CollectivePermuteOp::allowDifferentMeshes() { return true; }

Type CollectivePermuteOp::getType() { return getResult().getType(); }

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

bool AllReduceOp::allowMissingInputSharding() { return true; }

Type AllReduceOp::getType() { return getResult().getType(); }

} // namespace sdy
} // namespace mlir

14 changes: 13 additions & 1 deletion shardy/dialect/sdy/ir/op_interface.td
@@ -248,7 +248,8 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
outSharding attribute.

**Constraints:**
- Operand must have a sharding.
- Operand must have a sharding or `allowMissingInputSharding()` returns
true.
- `out_sharding` is valid w.r.t the corresponding type.
- Operand and result sharding must have the same mesh if
`allowDifferentMeshes()` returns false.
@@ -291,6 +292,17 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return false;"
>,
InterfaceMethod<
/*desc=*/[{
Indicates whether the collective op allows the input to have no
sharding, i.e., implicitly fully replicated.
}],
/*retType=*/"bool",
/*methodName=*/"allowMissingInputSharding",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return false;"
>
];
let verify = [{
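For collectives that keep the default `allowMissingInputSharding()` implementation (returning `false`), verification is unchanged: an operand without a sharding is still rejected. A sketch mirroring the `sdy.collective_permute` verification test later in this PR, with the mesh definition assumed:

```mlir
// Assumed mesh definition, matching the verification test.
sdy.mesh @mesh = <["x"=2, "y"=2]>

// collective_permute does not override allowMissingInputSharding(), so the
// verifier still emits "collective on operand without sharding" here.
func.func @collective_permute_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
  // expected-error @+1 {{collective on operand without sharding}}
  %0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{}, {"x"}]> : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}
```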
8 changes: 6 additions & 2 deletions shardy/dialect/sdy/ir/ops.td
@@ -493,7 +493,9 @@ def Sdy_AllGatherOp : Sdy_Op<"all_gather",
}

def Sdy_AllSliceOp : Sdy_Op<"all_slice",
[SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
[SameOperandsAndResultType, InferTypeOpInterface,
DeclareOpInterfaceMethods<Sdy_CollectiveOpInterface,
/*methodOverrides=*/["allowMissingInputSharding"]>]> {
let summary = "Performs a dynamic-slice operation along axes";
let description = [{
Slices chunks of a tensor along axes specified in `slicing_axes`. There is
@@ -633,7 +635,9 @@ def Sdy_CollectivePermuteOp : Sdy_Op<"collective_permute",
}

def Sdy_AllReduceOp: Sdy_Op<"all_reduce",
[SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
[SameOperandsAndResultType, InferTypeOpInterface,
DeclareOpInterfaceMethods<Sdy_CollectiveOpInterface,
/*methodOverrides=*/["allowMissingInputSharding"]>]> {
let summary = "Perform an all-reduce comunication along axes";
let description = [{
Reduces chunks of a tensor along axes specified in `reduction_axes`.
14 changes: 14 additions & 0 deletions shardy/dialect/sdy/ir/test/collective_parse_print.mlir
@@ -93,6 +93,13 @@ func.func @all_slice4(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @all_slice_missing_in_sharding
func.func @all_slice_missing_in_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]>
%0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @all_slice_subaxis_exact_match
func.func @all_slice_subaxis_exact_match(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3, [{"y"}, {}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{}, {"x":(1)2}] %arg0 out_sharding=<@mesh3, [{"y"}, {"x":(1)2}]>
@@ -233,6 +240,13 @@ func.func @all_reduce(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh
return %0 : tensor<16x2xf32>
}

// CHECK-LABEL: func @all_reduce_missing_in_sharding
func.func @all_reduce_missing_in_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
// CHECK-NEXT: sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
%0 = sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
return %0 : tensor<16x2xf32>
}

sdy.mesh @mesh_xyzw = <["x"=2, "y"=2, "z"=2, "w"=2]>

// CHECK-LABEL: func @all_reduce_many_axes
29 changes: 10 additions & 19 deletions shardy/dialect/sdy/ir/test/collective_verification.mlir
@@ -132,16 +132,6 @@ func.func @all_gather_with_incompatible_result_sharding_subaxis(%arg0 : tensor<1

// -----

sdy.mesh @mesh = <["x"=2, "y"=2]>

func.func @all_slice_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
// expected-error @+1 {{collective on operand without sharding}}
%0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// -----

sdy.mesh @mesh1 = <["x"=2, "y"=2]>
sdy.mesh @mesh2 = <["a"=2, "b"=2]>

@@ -398,6 +388,16 @@ func.func @all_to_all_incompatible_result_sharding_non_moved_dim(%arg0 : tensor<

sdy.mesh @mesh = <["x"=2, "y"=2]>

func.func @collective_permute_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
// expected-error @+1 {{collective on operand without sharding}}
%0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{}, {"x"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// -----

sdy.mesh @mesh = <["x"=2, "y"=2]>

func.func @collective_permute_invalid_out_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
// expected-error @+1 {{duplicate axis ref: "x"}}
%0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{"y", "x"}, {"x"}]> : tensor<16x8xf32>
@@ -519,15 +519,6 @@ sdy.mesh @mesh= <["x"=2, "y"=8, "z"=2]>

// -----

sdy.mesh @mesh = <["x"=2, "y"=2]>

func.func @all_reduce_on_operand_without_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
// expected-error @+1 {{'sdy.all_reduce' op collective on operand without sharding}}
%0 = sdy.all_reduce {"x"} %arg0 out_sharding=<@mesh, [{"y"}, {}]> : tensor<16x2xf32>
return %0 : tensor<16x2xf32>
}
// -----

sdy.mesh @mesh = <["x"=4, "y"=2]>

func.func @all_reduce_reduction_axes_can_be_merged(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x2xf32> {
15 changes: 11 additions & 4 deletions shardy/dialect/sdy/ir/utils.cc
@@ -316,16 +316,23 @@ TensorShardingAttr getSharding(Value value) {
});
}

TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
const bool closedIfMissing) {
if (TensorShardingAttr sharding = getSharding(value)) {
return sharding;
}
return closedIfMissing
? TensorShardingAttr::getFullyClosed(
value.getContext(), getTensorRank(value), meshName)
: TensorShardingAttr::getFullyOpen(value.getContext(),
getTensorRank(value), meshName);
value.getContext(), getTensorRank(value), meshOrRef)
: TensorShardingAttr::getFullyOpen(
value.getContext(), getTensorRank(value), meshOrRef);
}

TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
const bool closedIfMissing) {
return getOrCreateSharding(
value, FlatSymbolRefAttr::get(value.getContext(), meshName),
closedIfMissing);
}

void setSharding(Value value, TensorShardingAttr sharding) {
6 changes: 6 additions & 0 deletions shardy/dialect/sdy/ir/utils.h
@@ -253,6 +253,12 @@ TensorShardingAttr getSharding(Value value);
TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
bool closedIfMissing = false);

// Returns the sharding of the given `value`, or a fully open (closed) empty
// `TensorShardingAttr` if `value` doesn't have a sharding and `closedIfMissing`
// is false (true).
TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
bool closedIfMissing = false);

// Sets the `TensorShardingPerValueAttr` of the given `op`, but
// replaces the sharding at the given `index` with the given `sharding`.
//
27 changes: 15 additions & 12 deletions shardy/dialect/sdy/ir/verifiers.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
@@ -1039,8 +1038,9 @@ LogicalResult verifyCollectiveWithAxesPerDim(
DimensionShardingAttr operandDimSharding,
ArrayRef<AxisRefAttr> dimCollectiveAxes, int64_t dim, MeshAttr mesh)>
getExpectedResultDimSharding) {
TensorShardingAttr operandSharding = getSharding(op.getOperand());
TensorShardingAttr resultSharding = op.getOutSharding();
TensorShardingAttr operandSharding =
getOrCreateSharding(op.getOperand(), resultSharding.getMeshOrRef());
MeshAttr mesh = resultSharding.getMesh(op);
MeshAttr operandMesh = operandSharding.getMesh(op);

@@ -1285,8 +1285,9 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
return failure();
}
// 1. Verify operand has a sharding.
TensorShardingAttr operandSharding = getSharding(collectiveOp.getTensor());
if (!operandSharding) {
TensorShardingAttr optionalOperandSharding =
getSharding(collectiveOp.getTensor());
if (!collectiveOp.allowMissingInputSharding() && !optionalOperandSharding) {
return collectiveOp.emitOpError("collective on operand without sharding");
}

@@ -1302,21 +1303,22 @@
// 3. Verify MeshAttr of result and operand is the same.
if (!collectiveOp.allowDifferentMeshes()) {
MeshAttr mesh = resultSharding.getMesh(collectiveOp);
MeshAttr operandMesh = operandSharding.getMesh(collectiveOp);
if (mesh != operandMesh) {
MeshAttr operandMesh = optionalOperandSharding
? optionalOperandSharding.getMesh(collectiveOp)
: nullptr;
if (operandMesh && mesh != operandMesh) {
return collectiveOp.emitOpError("result mesh does not match operand mesh")
.attachNote(collectiveOp.getTensor().getLoc())
<< "operand mesh: " << operandMesh;
}
}

// 4. Verify same rank of the result sharding and operand sharding.
auto resultDimShardings = resultSharding.getRank();
auto operandDimShardings = operandSharding.getRank();
if (resultDimShardings != operandDimShardings) {
if (optionalOperandSharding &&
resultSharding.getRank() != optionalOperandSharding.getRank()) {
return collectiveOp.emitOpError("result sharding has rank ")
<< resultDimShardings << " but operand sharding has rank "
<< operandDimShardings;
<< resultSharding.getRank() << " but operand sharding has rank "
<< optionalOperandSharding.getRank();
}
return success();
}
@@ -1373,8 +1375,9 @@ LogicalResult SdyDialect::verifyOperationAttribute(Operation* op,
}

LogicalResult AllReduceOp::verify() {
TensorShardingAttr operandSharding = getSharding(getOperand());
TensorShardingAttr resultSharding = getOutSharding();
TensorShardingAttr operandSharding =
getOrCreateSharding(getOperand(), resultSharding.getMeshOrRef());
MeshAttr mesh = resultSharding.getMesh(*this);
if (!operandSharding.areDimAxesEqual(resultSharding)) {
return emitOpError("operand and result sharding have different axes");
Expand Down
@@ -1235,8 +1235,9 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
LogicalResult matchAndRewrite(
ReshardOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
TensorShardingAttr inSharding = getSharding(adaptor.getInput());
TensorShardingAttr outSharding = adaptor.getSharding();
TensorShardingAttr inSharding =
getOrCreateSharding(adaptor.getInput(), outSharding.getMeshName());
// Here it's safe to assume that shardings' meshes have a name.
if (inSharding.getRank() != outSharding.getRank()) {
return rewriter.notifyMatchFailure(
@@ -57,17 +57,24 @@ func.func @all_slice_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.s
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @all_slice_with_subaxis
func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
%0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @all_slice_minor_axis
func.func @all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{}, {"z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
%0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @all_slice_with_subaxis
func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x":(1)2}, {"y"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{"x":(2)2}, {"z":(1)2}] %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]>
%0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]> : tensor<16x8xf32>
// CHECK-LABEL: func @all_slice_missing_input_sharding
func.func @all_slice_missing_input_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{}, {}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
%0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}
