@@ -18,7 +18,6 @@ limitations under the License.
18
18
#include < algorithm>
19
19
#include < cstdint>
20
20
#include < functional>
21
- #include < iterator>
22
21
#include < numeric>
23
22
#include < optional>
24
23
#include < utility>
@@ -1039,8 +1038,9 @@ LogicalResult verifyCollectiveWithAxesPerDim(
1039
1038
DimensionShardingAttr operandDimSharding,
1040
1039
ArrayRef<AxisRefAttr> dimCollectiveAxes, int64_t dim, MeshAttr mesh)>
1041
1040
getExpectedResultDimSharding) {
1042
- TensorShardingAttr operandSharding = getSharding (op.getOperand ());
1043
1041
TensorShardingAttr resultSharding = op.getOutSharding ();
1042
+ TensorShardingAttr operandSharding =
1043
+ getOrCreateSharding (op.getOperand (), resultSharding.getMeshOrRef ());
1044
1044
MeshAttr mesh = resultSharding.getMesh (op);
1045
1045
MeshAttr operandMesh = operandSharding.getMesh (op);
1046
1046
@@ -1285,8 +1285,9 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
1285
1285
return failure ();
1286
1286
}
1287
1287
// 1. Verify operand has a sharding.
1288
- TensorShardingAttr operandSharding = getSharding (collectiveOp.getTensor ());
1289
- if (!operandSharding) {
1288
+ TensorShardingAttr optionalOperandSharding =
1289
+ getSharding (collectiveOp.getTensor ());
1290
+ if (!collectiveOp.allowMissingInputSharding () && !optionalOperandSharding) {
1290
1291
return collectiveOp.emitOpError (" collective on operand without sharding" );
1291
1292
}
1292
1293
@@ -1302,21 +1303,22 @@ LogicalResult verifyCollectiveOp(Operation* rawOp) {
1302
1303
// 3. Verify MeshAttr of result and operand is the same.
1303
1304
if (!collectiveOp.allowDifferentMeshes ()) {
1304
1305
MeshAttr mesh = resultSharding.getMesh (collectiveOp);
1305
- MeshAttr operandMesh = operandSharding.getMesh (collectiveOp);
1306
- if (mesh != operandMesh) {
1306
+ MeshAttr operandMesh = optionalOperandSharding
1307
+ ? optionalOperandSharding.getMesh (collectiveOp)
1308
+ : nullptr ;
1309
+ if (operandMesh && mesh != operandMesh) {
1307
1310
return collectiveOp.emitOpError (" result mesh does not match operand mesh" )
1308
1311
.attachNote (collectiveOp.getTensor ().getLoc ())
1309
1312
<< " operand mesh: " << operandMesh;
1310
1313
}
1311
1314
}
1312
1315
1313
1316
// 4. Verify same rank of the result sharding and operand sharding.
1314
- auto resultDimShardings = resultSharding.getRank ();
1315
- auto operandDimShardings = operandSharding.getRank ();
1316
- if (resultDimShardings != operandDimShardings) {
1317
+ if (optionalOperandSharding &&
1318
+ resultSharding.getRank () != optionalOperandSharding.getRank ()) {
1317
1319
return collectiveOp.emitOpError (" result sharding has rank " )
1318
- << resultDimShardings << " but operand sharding has rank "
1319
- << operandDimShardings ;
1320
+ << resultSharding. getRank () << " but operand sharding has rank "
1321
+ << optionalOperandSharding. getRank () ;
1320
1322
}
1321
1323
return success ();
1322
1324
}
@@ -1373,8 +1375,9 @@ LogicalResult SdyDialect::verifyOperationAttribute(Operation* op,
1373
1375
}
1374
1376
1375
1377
LogicalResult AllReduceOp::verify () {
1376
- TensorShardingAttr operandSharding = getSharding (getOperand ());
1377
1378
TensorShardingAttr resultSharding = getOutSharding ();
1379
+ TensorShardingAttr operandSharding =
1380
+ getOrCreateSharding (getOperand (), resultSharding.getMeshOrRef ());
1378
1381
MeshAttr mesh = resultSharding.getMesh (*this );
1379
1382
if (!operandSharding.areDimAxesEqual (resultSharding)) {
1380
1383
return emitOpError (" operand and result sharding have different axes" );
0 commit comments