
Commit ea9ba4f

tomnatan30 authored and copybara-github committed

#sdy Allow all-slice and all-reduce to not have an input sharding, since the input can have a fully replicated sharding.

PiperOrigin-RevId: 737948972
1 parent 3e5e377 commit ea9ba4f

11 files changed: +127 -43 lines
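
The net effect: a collective op whose operand has no `sdy.sharding` is now treated as implicitly fully replicated instead of failing verification. A minimal C++ sketch of the fallback, assuming a `collectiveOp` that implements `Sdy_CollectiveOpInterface` (simplified from the verifiers.cc change below; not a verbatim excerpt):

// Sketch only: `collectiveOp` is assumed to be an sdy collective in scope.
TensorShardingAttr resultSharding = collectiveOp.getOutSharding();
// If the operand has no sharding, materialize a fully open sharding on the
// mesh referenced by the result sharding instead of rejecting the op.
TensorShardingAttr operandSharding =
    getOrCreateSharding(collectiveOp.getTensor(), resultSharding.getMeshOrRef());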

shardy/dialect/sdy/ir/attrs.td

+10

@@ -731,6 +731,11 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
       std::function<InFlightDiagnostic(StringRef)> emitError,
       bool checkDivisibility = true);

+  // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
+  // being marked closed (cannot be further replicated/sharded).
+  static TensorShardingAttr getFullyClosed(
+      MLIRContext* context, int64_t rank, Attribute meshOrRef);
+
   // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
   // being marked closed (cannot be further replicated/sharded).
   static TensorShardingAttr getFullyClosed(
@@ -750,6 +755,11 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
       MLIRContext* context, Attribute meshOrRef,
      ArrayRef<SmallVector<AxisRefAttr>> axesPerDim);

+  // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
+  // being marked open (can be further replicated/sharded).
+  static TensorShardingAttr getFullyOpen(
+      MLIRContext* context, int64_t rank, Attribute meshOrRef);
+
   // Builds a `TensorShardingAttr` with all dim shardings and replicated axes
   // being marked open (can be further replicated/sharded).
   static TensorShardingAttr getFullyOpen(
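
The new overloads mirror the existing `StringRef meshName` builders but take the mesh as an `Attribute`, either an inlined mesh or a symbol reference to a named mesh. A hedged usage sketch (variable names are illustrative, not from the commit):

// Sketch only: `ctx` is an MLIRContext* assumed to be in scope.
Attribute meshOrRef = FlatSymbolRefAttr::get(ctx, "mesh");
TensorShardingAttr closed =
    TensorShardingAttr::getFullyClosed(ctx, /*rank=*/2, meshOrRef);
TensorShardingAttr open =
    TensorShardingAttr::getFullyOpen(ctx, /*rank=*/2, meshOrRef);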

shardy/dialect/sdy/ir/dialect.cc

+29

@@ -786,6 +786,12 @@ TensorShardingAttr TensorShardingAttr::getReplicated(StringRef axisName,
                                 getDimShardings(), newReplicatedAxes);
 }

+TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
+                                                      int64_t rank,
+                                                      Attribute meshOrRef) {
+  return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/true);
+}
+
 TensorShardingAttr TensorShardingAttr::getFullyClosed(MLIRContext* context,
                                                       int64_t rank,
                                                       StringRef meshName) {
@@ -823,6 +829,13 @@ TensorShardingAttr TensorShardingAttr::getClosed(
                                  /*replicatedAxes=*/{});
 }

+TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
+                                                    int64_t rank,
+                                                    Attribute meshOrRef) {
+  return getTensorShardingAttr(context, rank, meshOrRef, /*isClosed=*/false);
+}
+
+
 TensorShardingAttr TensorShardingAttr::getFullyOpen(MLIRContext* context,
                                                     int64_t rank,
                                                     StringRef meshName) {
@@ -1393,6 +1406,14 @@ LogicalResult NamedComputationOp::inferReturnTypes(
   return success();
 }

+//===----------------------------------------------------------------------===//
+// AllSliceOp
+//===----------------------------------------------------------------------===//
+
+bool AllSliceOp::allowMissingInputSharding() { return true; }
+
+Type AllSliceOp::getType() { return getResult().getType(); }
+
 //===----------------------------------------------------------------------===//
 // CollectivePermuteOp
 //===----------------------------------------------------------------------===//
@@ -1401,6 +1422,14 @@ bool CollectivePermuteOp::allowDifferentMeshes() { return true; }

 Type CollectivePermuteOp::getType() { return getResult().getType(); }

+//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+bool AllReduceOp::allowMissingInputSharding() { return true; }
+
+Type AllReduceOp::getType() { return getResult().getType(); }
+
 } // namespace sdy
 } // namespace mlir
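
Only these two ops override the hook; other collectives keep the interface default. A hedged sketch of calling the override defined above (illustrative; `allSliceOp` is assumed to be an sdy::AllSliceOp already in scope):

// Sketch only: returns true after this change; collectives that do not
// override the hook keep the interface default of false.
bool allowsMissingInput = allSliceOp.allowMissingInputSharding();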

shardy/dialect/sdy/ir/op_interface.td

+13 -1

@@ -248,7 +248,8 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
       outSharding attribute.

       **Constraints:**
-      - Operand must have a sharding.
+      - Operand must have a sharding or `allowMissingInputSharding()` returns
+        true.
       - `out_sharding` is valid w.r.t the corresponding type.
       - Operand and result sharding must have the same mesh if
         `allowDifferentMeshes()` returns false.
@@ -291,6 +292,17 @@ def Sdy_CollectiveOpInterface : OpInterface<"CollectiveOpInterface"> {
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return false;"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Indicates whether the collective op allows the input to have no
+          sharding, i.e., implicitly fully replicated.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"allowMissingInputSharding",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
       >
   ];
   let verify = [{
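
The default implementation preserves the old behavior (an input sharding is required); only ops that opt in, like `sdy.all_slice` and `sdy.all_reduce` below, return true. A hedged C++ sketch of querying the hook through the interface (the surrounding code is illustrative, not part of the commit):

// Sketch only: `op` is assumed to be an Operation* in scope.
if (auto collective = dyn_cast<CollectiveOpInterface>(op)) {
  if (!collective.allowMissingInputSharding() &&
      !getSharding(collective.getTensor())) {
    // Ops that keep the default (false) still reject a missing input
    // sharding, exactly as the verifier below does.
  }
}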

shardy/dialect/sdy/ir/ops.td

+6 -2

@@ -493,7 +493,9 @@ def Sdy_AllGatherOp : Sdy_Op<"all_gather",
 }

 def Sdy_AllSliceOp : Sdy_Op<"all_slice",
-    [SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
+    [SameOperandsAndResultType, InferTypeOpInterface,
+     DeclareOpInterfaceMethods<Sdy_CollectiveOpInterface,
+                               /*methodOverrides=*/["allowMissingInputSharding"]>]> {
   let summary = "Performs a dynamic-slice operation along axes";
   let description = [{
     Slices chunks of a tensor along axes specified in `slicing_axes`. There is
@@ -633,7 +635,9 @@ def Sdy_CollectivePermuteOp : Sdy_Op<"collective_permute",
 }

 def Sdy_AllReduceOp: Sdy_Op<"all_reduce",
-    [SameOperandsAndResultType, InferTypeOpInterface, Sdy_CollectiveOpInterface]> {
+    [SameOperandsAndResultType, InferTypeOpInterface,
+     DeclareOpInterfaceMethods<Sdy_CollectiveOpInterface,
+                               /*methodOverrides=*/["allowMissingInputSharding"]>]> {
   let summary = "Perform an all-reduce communication along axes";
   let description = [{
     Reduces chunks of a tensor along axes specified in `reduction_axes`.

shardy/dialect/sdy/ir/test/collective_parse_print.mlir

+14

@@ -93,6 +93,13 @@ func.func @all_slice4(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh
   return %0 : tensor<16x8xf32>
 }

+// CHECK-LABEL: func @all_slice_missing_in_sharding
+func.func @all_slice_missing_in_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]>
+  %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh1, [{}, {"x"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @all_slice_subaxis_exact_match
 func.func @all_slice_subaxis_exact_match(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3, [{"y"}, {}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_slice [{}, {"x":(1)2}] %arg0 out_sharding=<@mesh3, [{"y"}, {"x":(1)2}]>
@@ -233,6 +240,13 @@ func.func @all_reduce(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh
   return %0 : tensor<16x2xf32>
 }

+// CHECK-LABEL: func @all_reduce_missing_in_sharding
+func.func @all_reduce_missing_in_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
+  // CHECK-NEXT: sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
+  %0 = sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh1, [{}, {}]> : tensor<16x2xf32>
+  return %0 : tensor<16x2xf32>
+}
+
 sdy.mesh @mesh_xyzw = <["x"=2, "y"=2, "z"=2, "w"=2]>

 // CHECK-LABEL: func @all_reduce_many_axes

shardy/dialect/sdy/ir/test/collective_verification.mlir

+10 -19

@@ -132,16 +132,6 @@ func.func @all_gather_with_incompatible_result_sharding_subaxis(%arg0 : tensor<1

 // -----

-sdy.mesh @mesh = <["x"=2, "y"=2]>
-
-func.func @all_slice_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
-  // expected-error @+1 {{collective on operand without sharding}}
-  %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x8xf32>
-  return %0 : tensor<16x8xf32>
-}
-
-// -----
-
 sdy.mesh @mesh1 = <["x"=2, "y"=2]>
 sdy.mesh @mesh2 = <["a"=2, "b"=2]>

@@ -398,6 +388,16 @@ func.func @all_to_all_incompatible_result_sharding_non_moved_dim(%arg0 : tensor<

 sdy.mesh @mesh = <["x"=2, "y"=2]>

+func.func @collective_permute_on_operand_without_sharding(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // expected-error @+1 {{collective on operand without sharding}}
+  %0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{}, {"x"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["x"=2, "y"=2]>
+
 func.func @collective_permute_invalid_out_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // expected-error @+1 {{duplicate axis ref: "x"}}
   %0 = sdy.collective_permute %arg0 out_sharding=<@mesh, [{"y", "x"}, {"x"}]> : tensor<16x8xf32>
@@ -519,15 +519,6 @@ sdy.mesh @mesh= <["x"=2, "y"=8, "z"=2]>

 // -----

-sdy.mesh @mesh = <["x"=2, "y"=2]>
-
-func.func @all_reduce_on_operand_without_sharding(%arg0 : tensor<16x2xf32>) -> tensor<16x2xf32> {
-  // expected-error @+1 {{'sdy.all_reduce' op collective on operand without sharding}}
-  %0 = sdy.all_reduce {"x"} %arg0 out_sharding=<@mesh, [{"y"}, {}]> : tensor<16x2xf32>
-  return %0 : tensor<16x2xf32>
-}
-// -----
-
 sdy.mesh @mesh = <["x"=4, "y"=2]>

 func.func @all_reduce_reduction_axes_can_be_merged(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x2xf32> {

shardy/dialect/sdy/ir/utils.cc

+11 -4

@@ -316,16 +316,23 @@ TensorShardingAttr getSharding(Value value) {
   });
 }

-TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
+TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
                                        const bool closedIfMissing) {
   if (TensorShardingAttr sharding = getSharding(value)) {
     return sharding;
   }
   return closedIfMissing
              ? TensorShardingAttr::getFullyClosed(
-                   value.getContext(), getTensorRank(value), meshName)
-             : TensorShardingAttr::getFullyOpen(value.getContext(),
-                                                getTensorRank(value), meshName);
+                   value.getContext(), getTensorRank(value), meshOrRef)
+             : TensorShardingAttr::getFullyOpen(
+                   value.getContext(), getTensorRank(value), meshOrRef);
+}
+
+TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
+                                       const bool closedIfMissing) {
+  return getOrCreateSharding(
+      value, FlatSymbolRefAttr::get(value.getContext(), meshName),
+      closedIfMissing);
 }

 void setSharding(Value value, TensorShardingAttr sharding) {

shardy/dialect/sdy/ir/utils.h

+6

@@ -253,6 +253,12 @@ TensorShardingAttr getSharding(Value value);
 TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,
                                        bool closedIfMissing = false);

+// Returns the sharding of the given `value`, or a fully open (closed) empty
+// `TensorShardingAttr` if `value` doesn't have a sharding and `closedIfMissing`
+// is false (true).
+TensorShardingAttr getOrCreateSharding(Value value, Attribute meshOrRef,
+                                       bool closedIfMissing = false);
+
 // Sets the `TensorShardingPerValueAttr` of the given `op`, but
 // replaces the sharding at the given `index` with the given `sharding`.
 //
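
For named meshes the two overloads are interchangeable; the `StringRef` form simply wraps the name in a `FlatSymbolRefAttr` (see utils.cc above). A hedged sketch (variable names are illustrative):

// Sketch only: `value` is a Value assumed to be in scope; both calls return
// the same fully open sharding when `value` has no sdy.sharding.
TensorShardingAttr byName = getOrCreateSharding(value, "mesh");
TensorShardingAttr byRef = getOrCreateSharding(
    value, FlatSymbolRefAttr::get(value.getContext(), "mesh"));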

shardy/dialect/sdy/ir/verifiers.cc

+15 -12

@@ -18,7 +18,6 @@ limitations under the License.
 #include <algorithm>
 #include <cstdint>
 #include <functional>
-#include <iterator>
 #include <numeric>
 #include <optional>
 #include <utility>
@@ -1039,8 +1038,9 @@ LogicalResult verifyCollectiveWithAxesPerDim(
         DimensionShardingAttr operandDimSharding,
         ArrayRef<AxisRefAttr> dimCollectiveAxes, int64_t dim, MeshAttr mesh)>
         getExpectedResultDimSharding) {
-  TensorShardingAttr operandSharding = getSharding(op.getOperand());
   TensorShardingAttr resultSharding = op.getOutSharding();
+  TensorShardingAttr operandSharding =
+      getOrCreateSharding(op.getOperand(), resultSharding.getMeshOrRef());
   MeshAttr mesh = resultSharding.getMesh(op);
   MeshAttr operandMesh = operandSharding.getMesh(op);

@@ -1285,8 +1285,9 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
     return failure();
   }
   // 1. Verify operand has a sharding.
-  TensorShardingAttr operandSharding = getSharding(collectiveOp.getTensor());
-  if (!operandSharding) {
+  TensorShardingAttr optionalOperandSharding =
+      getSharding(collectiveOp.getTensor());
+  if (!collectiveOp.allowMissingInputSharding() && !optionalOperandSharding) {
     return collectiveOp.emitOpError("collective on operand without sharding");
   }

@@ -1302,21 +1303,22 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
   // 3. Verify MeshAttr of result and operand is the same.
   if (!collectiveOp.allowDifferentMeshes()) {
     MeshAttr mesh = resultSharding.getMesh(collectiveOp);
-    MeshAttr operandMesh = operandSharding.getMesh(collectiveOp);
-    if (mesh != operandMesh) {
+    MeshAttr operandMesh = optionalOperandSharding
+                               ? optionalOperandSharding.getMesh(collectiveOp)
+                               : nullptr;
+    if (operandMesh && mesh != operandMesh) {
      return collectiveOp.emitOpError("result mesh does not match operand mesh")
          .attachNote(collectiveOp.getTensor().getLoc())
          << "operand mesh: " << operandMesh;
     }
   }

   // 4. Verify same rank of the result sharding and operand sharding.
-  auto resultDimShardings = resultSharding.getRank();
-  auto operandDimShardings = operandSharding.getRank();
-  if (resultDimShardings != operandDimShardings) {
+  if (optionalOperandSharding &&
+      resultSharding.getRank() != optionalOperandSharding.getRank()) {
     return collectiveOp.emitOpError("result sharding has rank ")
-           << resultDimShardings << " but operand sharding has rank "
-           << operandDimShardings;
+           << resultSharding.getRank() << " but operand sharding has rank "
+           << optionalOperandSharding.getRank();
   }
   return success();
 }
@@ -1373,8 +1375,9 @@ LogicalResult SdyDialect::verifyOperationAttribute(Operation* op,
 }

 LogicalResult AllReduceOp::verify() {
-  TensorShardingAttr operandSharding = getSharding(getOperand());
   TensorShardingAttr resultSharding = getOutSharding();
+  TensorShardingAttr operandSharding =
+      getOrCreateSharding(getOperand(), resultSharding.getMeshOrRef());
   MeshAttr mesh = resultSharding.getMesh(*this);
   if (!operandSharding.areDimAxesEqual(resultSharding)) {
     return emitOpError("operand and result sharding have different axes");

shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc

+2 -1

@@ -1235,8 +1235,9 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
   LogicalResult matchAndRewrite(
       ReshardOp op, OpAdaptor adaptor,
       ConversionPatternRewriter& rewriter) const override {
-    TensorShardingAttr inSharding = getSharding(adaptor.getInput());
     TensorShardingAttr outSharding = adaptor.getSharding();
+    TensorShardingAttr inSharding =
+        getOrCreateSharding(adaptor.getInput(), outSharding.getMeshName());
     // Here it's safe to assume that shardings' meshes have a name.
     if (inSharding.getRank() != outSharding.getRank()) {
       return rewriter.notifyMatchFailure(
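
With this change a `sdy.reshard` whose input has no sharding is lowered as if the input were fully replicated on the mesh named by the output sharding. A hedged sketch of the fallback the pattern now relies on (variable names are illustrative; only the call itself comes from the commit):

// Sketch only: `input` (a Value) and `outSharding` (the reshard's target
// TensorShardingAttr) are assumed to be in scope, as inside ReshardPattern.
// When `input` carries no sdy.sharding, this yields a fully open sharding on
// the named mesh, so the reshard still lowers to collectives.
TensorShardingAttr inSharding =
    getOrCreateSharding(input, outSharding.getMeshName());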

shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir

+11 -4

@@ -57,17 +57,24 @@ func.func @all_slice_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.s
   return %0 : tensor<16x8xf32>
 }

+// CHECK-LABEL: func @all_slice_with_subaxis
+func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
+  %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @all_slice_minor_axis
 func.func @all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_slice [{}, {"z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
   %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
   return %0 : tensor<16x8xf32>
 }

-// CHECK-LABEL: func @all_slice_with_subaxis
-func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x":(1)2}, {"y"}]>}) -> tensor<16x8xf32> {
-  // CHECK-NEXT: sdy.all_slice [{"x":(2)2}, {"z":(1)2}] %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]>
-  %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]> : tensor<16x8xf32>
+// CHECK-LABEL: func @all_slice_missing_input_sharding
+func.func @all_slice_missing_input_sharding(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{}, {}]>}) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]>
+  %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32>
   return %0 : tensor<16x8xf32>
 }
