Skip to content

Commit a0d337e

Browse files
Introduce isFullyReplicated() for TensorShardingAttr
PiperOrigin-RevId: 658095168
1 parent 71f2363 commit a0d337e

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

shardy/dialect/sdy/ir/attrs.td

+7
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,13 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
370370
return getDimSharding(dim).getIsClosed();
371371
}
372372

373+
bool isFullyReplicated() const {
374+
return llvm::all_of(getDimShardings(),
375+
[](const DimensionShardingAttr dimSharding) {
376+
return dimSharding.emptyAxes();
377+
});
378+
}
379+
373380
StringRef getMeshName() const {
374381
return getMeshSymName().getValue();
375382
}

0 commit comments

Comments
 (0)