@@ -681,6 +681,52 @@ bool shouldReshard(TensorShardingAttr sourceSharding,
681
681
return sourceSharding != targetSharding;
682
682
}
683
683
684
+ void insertExplicitReshardsOnFuncReturn (Operation* op, func::FuncOp& funcOp,
685
+ IRRewriter& rewriter) {
686
+ rewriter.setInsertionPoint (op);
687
+ for (const auto & [index , opOperand] : llvm::enumerate (op->getOpOperands ())) {
688
+ Value operand = opOperand.get ();
689
+ TensorShardingAttr funcResultSharding =
690
+ getFuncResultSharding (funcOp, index );
691
+ TensorShardingAttr operandSharding = getSharding (operand);
692
+ if (shouldReshard (operandSharding, funcResultSharding)) {
693
+ auto reshardOp = rewriter.create <ReshardOp>(
694
+ operand.getLoc (), operand,
695
+ funcResultSharding
696
+ ? funcResultSharding
697
+ // Since it should reshard and `funcResultSharding` is empty,
698
+ // `operandSharding` is guaranteed to be nonempty.
699
+ : TensorShardingAttr::getFullyClosedLike (operandSharding));
700
+ opOperand.set (reshardOp);
701
+ }
702
+ }
703
+ }
704
+
705
+ void insertExplicitReshardsOnDataFlowOp (ShardableDataFlowOpInterface& op,
706
+ IRRewriter& rewriter,
707
+ StringRef meshName) {
708
+ for (Value owner : llvm::concat<Value>(op.getOpResultEdgeOwners (),
709
+ op.getBlockArgumentEdgeOwners ())) {
710
+ TensorShardingAttr ownerSharding = op.transformTargetSharding (
711
+ owner, op.getEdgeOwnerSharding (owner),
712
+ DataFlowShardingTransformType::kBeforeEdgePropagation );
713
+ for (OpOperand* sourceOpOperand : op.getEdgeSources (owner)) {
714
+ Value source = sourceOpOperand->get ();
715
+ TensorShardingAttr sourceSharding =
716
+ getOrCreateSharding (source, meshName, /* closedIfMissing=*/ true );
717
+ if (shouldReshard (sourceSharding, ownerSharding)) {
718
+ rewriter.setInsertionPointAfterValue (source);
719
+ auto reshardOp = rewriter.create <ReshardOp>(
720
+ source.getLoc (), source,
721
+ ownerSharding
722
+ ? ownerSharding
723
+ : TensorShardingAttr::getFullyClosedLike (sourceSharding));
724
+ sourceOpOperand->set (reshardOp);
725
+ }
726
+ }
727
+ }
728
+ }
729
+
684
730
struct InsertExplicitReshardsPass
685
731
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
686
732
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
@@ -689,10 +735,12 @@ struct InsertExplicitReshardsPass
689
735
func::FuncOp funcOp = getOperation ();
690
736
IRRewriter rewriter (funcOp);
691
737
SymbolTable symbolTable (funcOp->getParentOfType <ModuleOp>());
692
- // TODO(enver): Handle data flow ops.
738
+
693
739
funcOp.walk ([&](Operation* op) {
694
- // TODO(enver): Check if data flow ops, data flow edge op, manual
695
- // computation op require extra check before creating sharding rule.
740
+ if (isa<func::ReturnOp>(op)) {
741
+ insertExplicitReshardsOnFuncReturn (op, funcOp, rewriter);
742
+ return ;
743
+ }
696
744
697
745
std::optional<StringRef> meshName =
698
746
getCommonMeshName (getShardings (op->getOperands ()),
@@ -706,56 +754,12 @@ struct InsertExplicitReshardsPass
706
754
return ;
707
755
}
708
756
709
- if (isa<func::ReturnOp>(op)) {
710
- rewriter.setInsertionPoint (op);
711
- for (const auto & [index , opOperand] :
712
- llvm::enumerate (op->getOpOperands ())) {
713
- Value operand = opOperand.get ();
714
- TensorShardingAttr funcResultSharding =
715
- getFuncResultSharding (funcOp, index );
716
- TensorShardingAttr operandSharding =
717
- getOrCreateSharding (operand, *meshName, /* closedIfMissing=*/ true );
718
- if (shouldReshard (operandSharding, funcResultSharding)) {
719
- // TODO(enver): Close all shardings and drop replicated axes before
720
- // this pass on the export pipeline.
721
- auto reshardOp = rewriter.create <ReshardOp>(
722
- operand.getLoc (), operand,
723
- funcResultSharding
724
- ? funcResultSharding
725
- : TensorShardingAttr::getFullyClosedLike (operandSharding));
726
- opOperand.set (reshardOp);
727
- }
728
- }
729
- return ;
730
- }
731
-
732
757
// TODO(enver): Prefer resharding the owner when multiple sources are
733
758
// sharded in the same way.
734
759
if (auto shardableDataFlowOp =
735
760
dyn_cast<ShardableDataFlowOpInterface>(op)) {
736
- for (Value owner : llvm::concat<Value>(
737
- shardableDataFlowOp.getOpResultEdgeOwners (),
738
- shardableDataFlowOp.getBlockArgumentEdgeOwners ())) {
739
- TensorShardingAttr ownerSharding =
740
- shardableDataFlowOp.transformTargetSharding (
741
- owner, shardableDataFlowOp.getEdgeOwnerSharding (owner),
742
- DataFlowShardingTransformType::kBeforeEdgePropagation );
743
- for (OpOperand* sourceOpOperand :
744
- shardableDataFlowOp.getEdgeSources (owner)) {
745
- Value source = sourceOpOperand->get ();
746
- TensorShardingAttr sourceSharding = getOrCreateSharding (
747
- source, *meshName, /* closedIfMissing=*/ true );
748
- if (shouldReshard (sourceSharding, ownerSharding)) {
749
- rewriter.setInsertionPointAfterValue (source);
750
- auto reshardOp = rewriter.create <ReshardOp>(
751
- source.getLoc (), source,
752
- ownerSharding
753
- ? ownerSharding
754
- : TensorShardingAttr::getFullyClosedLike (sourceSharding));
755
- sourceOpOperand->set (reshardOp);
756
- }
757
- }
758
- }
761
+ insertExplicitReshardsOnDataFlowOp (shardableDataFlowOp, rewriter,
762
+ *meshName);
759
763
return ;
760
764
}
761
765
0 commit comments