Skip to content

Commit c72dfdb

Browse files
Also reshard func return operands when the corresponding return operand and func result have different meshes,
unless both are fully replicated. PiperOrigin-RevId: 737622545
1 parent 6ccfb39 commit c72dfdb

File tree

2 files changed

+89
-49
lines changed

2 files changed

+89
-49
lines changed

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

+53-49
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,52 @@ bool shouldReshard(TensorShardingAttr sourceSharding,
681681
return sourceSharding != targetSharding;
682682
}
683683

684+
// Inserts a `ReshardOp` before `op` (the func return) on every return operand
// whose sharding disagrees with the sharding declared on the corresponding
// result of `funcOp`, per `shouldReshard`.
void insertExplicitReshardsOnFuncReturn(Operation* op, func::FuncOp& funcOp,
                                        IRRewriter& rewriter) {
  rewriter.setInsertionPoint(op);
  for (const auto& [resultIndex, returnOperand] :
       llvm::enumerate(op->getOpOperands())) {
    Value returnValue = returnOperand.get();
    TensorShardingAttr resultSharding =
        getFuncResultSharding(funcOp, resultIndex);
    TensorShardingAttr returnSharding = getSharding(returnValue);
    if (!shouldReshard(returnSharding, resultSharding)) {
      continue;
    }
    // When a reshard is needed but the func-result sharding is empty, the
    // return operand's sharding is guaranteed to be nonempty; use its fully
    // closed form as the reshard target.
    TensorShardingAttr targetSharding =
        resultSharding
            ? resultSharding
            : TensorShardingAttr::getFullyClosedLike(returnSharding);
    returnOperand.set(rewriter.create<ReshardOp>(returnValue.getLoc(),
                                                 returnValue, targetSharding));
  }
}
704+
705+
// For each data-flow edge of `op`, reshards every edge source whose sharding
// disagrees with the (pre-propagation-transformed) sharding of the edge owner,
// inserting the `ReshardOp` right after the source value.
void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
                                        IRRewriter& rewriter,
                                        StringRef meshName) {
  for (Value edgeOwner : llvm::concat<Value>(op.getOpResultEdgeOwners(),
                                             op.getBlockArgumentEdgeOwners())) {
    TensorShardingAttr targetSharding = op.transformTargetSharding(
        edgeOwner, op.getEdgeOwnerSharding(edgeOwner),
        DataFlowShardingTransformType::kBeforeEdgePropagation);
    for (OpOperand* edgeSource : op.getEdgeSources(edgeOwner)) {
      Value sourceValue = edgeSource->get();
      TensorShardingAttr currentSharding =
          getOrCreateSharding(sourceValue, meshName, /*closedIfMissing=*/true);
      if (!shouldReshard(currentSharding, targetSharding)) {
        continue;
      }
      rewriter.setInsertionPointAfterValue(sourceValue);
      // An empty owner sharding with a required reshard implies the source
      // sharding is nonempty; fall back to its fully closed form.
      edgeSource->set(rewriter.create<ReshardOp>(
          sourceValue.getLoc(), sourceValue,
          targetSharding
              ? targetSharding
              : TensorShardingAttr::getFullyClosedLike(currentSharding)));
    }
  }
}
729+
684730
struct InsertExplicitReshardsPass
685731
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
686732
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
@@ -689,10 +735,12 @@ struct InsertExplicitReshardsPass
689735
func::FuncOp funcOp = getOperation();
690736
IRRewriter rewriter(funcOp);
691737
SymbolTable symbolTable(funcOp->getParentOfType<ModuleOp>());
692-
// TODO(enver): Handle data flow ops.
738+
693739
funcOp.walk([&](Operation* op) {
694-
// TODO(enver): Check if data flow ops, data flow edge op, manual
695-
// computation op require extra check before creating sharding rule.
740+
if (isa<func::ReturnOp>(op)) {
741+
insertExplicitReshardsOnFuncReturn(op, funcOp, rewriter);
742+
return;
743+
}
696744

697745
std::optional<StringRef> meshName =
698746
getCommonMeshName(getShardings(op->getOperands()),
@@ -706,56 +754,12 @@ struct InsertExplicitReshardsPass
706754
return;
707755
}
708756

709-
if (isa<func::ReturnOp>(op)) {
710-
rewriter.setInsertionPoint(op);
711-
for (const auto& [index, opOperand] :
712-
llvm::enumerate(op->getOpOperands())) {
713-
Value operand = opOperand.get();
714-
TensorShardingAttr funcResultSharding =
715-
getFuncResultSharding(funcOp, index);
716-
TensorShardingAttr operandSharding =
717-
getOrCreateSharding(operand, *meshName, /*closedIfMissing=*/true);
718-
if (shouldReshard(operandSharding, funcResultSharding)) {
719-
// TODO(enver): Close all shardings and drop replicated axes before
720-
// this pass on the export pipeline.
721-
auto reshardOp = rewriter.create<ReshardOp>(
722-
operand.getLoc(), operand,
723-
funcResultSharding
724-
? funcResultSharding
725-
: TensorShardingAttr::getFullyClosedLike(operandSharding));
726-
opOperand.set(reshardOp);
727-
}
728-
}
729-
return;
730-
}
731-
732757
// TODO(enver): Prefer resharding the owner when multiple sources are
733758
// sharded in the same way.
734759
if (auto shardableDataFlowOp =
735760
dyn_cast<ShardableDataFlowOpInterface>(op)) {
736-
for (Value owner : llvm::concat<Value>(
737-
shardableDataFlowOp.getOpResultEdgeOwners(),
738-
shardableDataFlowOp.getBlockArgumentEdgeOwners())) {
739-
TensorShardingAttr ownerSharding =
740-
shardableDataFlowOp.transformTargetSharding(
741-
owner, shardableDataFlowOp.getEdgeOwnerSharding(owner),
742-
DataFlowShardingTransformType::kBeforeEdgePropagation);
743-
for (OpOperand* sourceOpOperand :
744-
shardableDataFlowOp.getEdgeSources(owner)) {
745-
Value source = sourceOpOperand->get();
746-
TensorShardingAttr sourceSharding = getOrCreateSharding(
747-
source, *meshName, /*closedIfMissing=*/true);
748-
if (shouldReshard(sourceSharding, ownerSharding)) {
749-
rewriter.setInsertionPointAfterValue(source);
750-
auto reshardOp = rewriter.create<ReshardOp>(
751-
source.getLoc(), source,
752-
ownerSharding
753-
? ownerSharding
754-
: TensorShardingAttr::getFullyClosedLike(sourceSharding));
755-
sourceOpOperand->set(reshardOp);
756-
}
757-
}
758-
}
761+
insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
762+
*meshName);
759763
return;
760764
}
761765

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir

+36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: sdy_opt %s -sdy-insert-explicit-reshards | FileCheck %s
22

33
sdy.mesh @mesh = <["x"=4, "y"=2]>
4+
sdy.mesh @mesh_xt = <["x"=2, "t"=4]>
45
sdy.mesh @mesh_xyz = <["x"=4, "y"=2, "z"=4]>
56
sdy.mesh @mesh_xyzt = <["x"=4, "y"=4, "z"=4, "t"=8]>
67

@@ -19,6 +20,41 @@ func.func @funcop_result_sharding_does_not_match(%arg0: tensor<8x16xf32> {sdy.sh
1920
return %arg0 : tensor<8x16xf32>
2021
}
2122

23+
// No reshard is inserted when the return operand and the func result use
// different meshes but both are fully replicated.
// CHECK-LABEL: func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result
func.func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{}, {}]>}) {
  // CHECK-NOT: sdy.reshard
  return %arg0 : tensor<8x16xf32>
}
28+
29+
// A reshard is inserted even though the dimension shardings look identical,
// because the operand and result refer to different meshes (axis "x" has
// size 4 in @mesh but size 2 in @mesh_xt — presumably why this counts as a
// mismatch; confirm against `shouldReshard`).
// CHECK-LABEL: func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result
func.func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {}]>}) {
  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x16xf32>
  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
  return %arg0 : tensor<8x16xf32>
}
35+
36+
// Both the shardings and the meshes differ; the operand is resharded to the
// func result's sharding on the func result's mesh.
// CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result
func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{}, {"t"}]>}) {
  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{}, {"t"}]> : tensor<8x16xf32>
  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
  return %arg0 : tensor<8x16xf32>
}
42+
43+
// Each return operand is resharded independently to its own result's mesh and
// sharding when they disagree; the two results use different meshes.
// CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results
func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"t"}, {}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {}]>}, tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x32xf32>
  // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{}, {"x"}]> : tensor<32x16xf32>
  // CHECK: return %[[RESHARD1]], %[[RESHARD2]] : tensor<8x32xf32>, tensor<32x16xf32>
  return %arg0, %arg1 : tensor<8x32xf32>, tensor<32x16xf32>
}
50+
51+
// A reshard is inserted even when the sub-axis shardings appear to describe
// equivalent device assignments over the two meshes — any mesh difference
// (other than both being fully replicated) triggers a reshard.
// CHECK-LABEL: func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result
func.func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}, {"y"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {"t":(2)2}]>}) {
  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {"t":(2)2}]> : tensor<8x16xf32>
  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
  return %arg0 : tensor<8x16xf32>
}
57+
2258
// CHECK-LABEL: func @funcop_result_sharding_does_not_match_funcop_result_empty
2359
func.func @funcop_result_sharding_does_not_match_funcop_result_empty(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> tensor<8x16xf32> {
2460
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> : tensor<8x16xf32>

0 commit comments

Comments
 (0)