Commit 8432f8c

Reshard func results and data flow ops also when the corresponding return operand and func result have different meshes, unless both are fully replicated.

PiperOrigin-RevId: 737622545
Parent: a8d83eb
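
In IR terms, a mismatch in the referenced mesh alone is now enough to trigger a reshard on a func result or a data-flow edge. Below is a minimal sketch of the func-result case; the mesh names, shapes, and func name here are assumed for illustration only (the tests added in this commit are the real coverage). The dimension shardings match, but the return operand refers to @mesh_a while the func result refers to @mesh_b, so the pass inserts an explicit reshard before the return:

    sdy.mesh @mesh_a = <["x"=2]>
    sdy.mesh @mesh_b = <["x"=2], device_ids=[1, 0]>

    func.func @different_meshes(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_b, [{"x"}, {}]>}) {
      // After the pass, the return operand is rewritten to:
      //   %0 = sdy.reshard %arg0 <@mesh_b, [{"x"}, {}]> : tensor<8x16xf32>
      //   return %0 : tensor<8x16xf32>
      return %arg0 : tensor<8x16xf32>
    }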

2 files changed: +79 −25

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc (+22 −25)

@@ -712,11 +712,9 @@ bool shouldReshard(TensorShardingAttr sourceSharding,
 void insertExplicitReshardsToTargetSharding(OpOperand* opOperand,
                                             TensorShardingAttr targetSharding,
                                             IRRewriter& rewriter,
-                                            StringRef meshName,
                                             const bool insertAfterOperand) {
   Value operand = opOperand->get();
-  TensorShardingAttr operandSharding =
-      getOrCreateSharding(operand, meshName, /*closedIfMissing=*/true);
+  TensorShardingAttr operandSharding = getSharding(operand);
   if (shouldReshard(operandSharding, targetSharding)) {
     if (insertAfterOperand) {
       rewriter.setInsertionPointAfterValue(operand);
@@ -725,26 +723,26 @@ void insertExplicitReshardsToTargetSharding(OpOperand* opOperand,
         operand.getLoc(), operand,
         targetSharding
             ? targetSharding
+            // Since it should reshard and `targetSharding` is empty,
+            // `operandSharding` is guaranteed to be nonempty.
             : TensorShardingAttr::getFullyClosedLike(operandSharding));
     opOperand->set(reshardOp);
   }
 }

 void insertExplicitReshardsOnFuncReturn(Operation* op, func::FuncOp& funcOp,
-                                        IRRewriter& rewriter,
-                                        StringRef meshName) {
+                                        IRRewriter& rewriter) {
   rewriter.setInsertionPoint(op);
   for (const auto& [index, opOperand] : llvm::enumerate(op->getOpOperands())) {
     insertExplicitReshardsToTargetSharding(
         /*opOperand=*/&opOperand,
         /*targetSharding=*/getFuncResultSharding(funcOp, index), rewriter,
-        meshName, /*insertAfterOperand=*/false);
+        /*insertAfterOperand=*/false);
   }
 }

 void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
-                                        IRRewriter& rewriter,
-                                        StringRef meshName) {
+                                        IRRewriter& rewriter) {
   for (Value owner : llvm::concat<Value>(op.getOpResultEdgeOwners(),
                                          op.getBlockArgumentEdgeOwners())) {
     TensorShardingAttr ownerSharding = op.transformTargetSharding(
@@ -753,7 +751,7 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
     for (OpOperand* sourceOpOperand : op.getEdgeSources(owner)) {
       insertExplicitReshardsToTargetSharding(
           /*opOperand=*/sourceOpOperand,
-          /*targetSharding=*/ownerSharding, rewriter, meshName,
+          /*targetSharding=*/ownerSharding, rewriter,
           /*insertAfterOperand=*/true);
     }
   }
@@ -769,6 +767,21 @@ struct InsertExplicitReshardsPass
     SymbolTable symbolTable(funcOp->getParentOfType<ModuleOp>());

     funcOp.walk([&](Operation* op) {
+      // TODO(enver): Does not need to be part of the walk on the func, instead
+      // get the terminator with getBodyTerminator.
+      if (isa<func::ReturnOp>(op)) {
+        insertExplicitReshardsOnFuncReturn(op, funcOp, rewriter);
+        return;
+      }
+
+      // TODO(enver): Prefer resharding the owner when multiple sources are
+      // sharded in the same way.
+      if (auto shardableDataFlowOp =
+              dyn_cast<ShardableDataFlowOpInterface>(op)) {
+        insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter);
+        return;
+      }
+
       SmallVector<TensorShardingAttr> inShardings =
           getShardings(op->getOperands());
       SmallVector<TensorShardingAttr> outShardings =
@@ -784,22 +797,6 @@ struct InsertExplicitReshardsPass
         return;
       }

-      // TODO(enver): Does not need to be part of the walk on the func, instead
-      // get the terminatior with getBodyTerminator.
-      if (isa<func::ReturnOp>(op)) {
-        insertExplicitReshardsOnFuncReturn(op, funcOp, rewriter, *meshName);
-        return;
-      }
-
-      // TODO(enver): Prefer resharding the owner when multiple sources are
-      // sharded in the same way.
-      if (auto shardableDataFlowOp =
-              dyn_cast<ShardableDataFlowOpInterface>(op)) {
-        insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
-                                           *meshName);
-        return;
-      }
-
       // NOTE: Creating a sharding rule requires data flow edges are present.
       OpShardingRuleAttr shardingRule =
           getOrCreateShardingRule(op, /*conservativePropagation=*/false,
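
The exemption in the commit message is the one case where differing meshes do not force a reshard: a fully replicated tensor is present in full on every device, so no data movement is needed. A minimal sketch (mesh names, shapes, and func name again assumed for illustration; the test @funcop_result_unsharded_but_different_meshes_between_return_and_func_result below is the real coverage):

    sdy.mesh @mesh_a = <["x"=2]>
    sdy.mesh @mesh_b = <["y"=2]>

    func.func @both_fully_replicated(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a, [{}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_b, [{}, {}]>}) {
      // The pass inserts no sdy.reshard here: both shardings are fully replicated,
      // even though they reference different meshes.
      return %arg0 : tensor<8x16xf32>
    }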

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir (+57 −0)

@@ -1,8 +1,11 @@
 // RUN: sdy_opt %s -sdy-insert-explicit-reshards | FileCheck %s

 sdy.mesh @mesh = <["x"=4, "y"=2]>
+sdy.mesh @mesh_xt = <["x"=2, "t"=4]>
 sdy.mesh @mesh_xyz = <["x"=4, "y"=2, "z"=4]>
 sdy.mesh @mesh_xyzt = <["x"=4, "y"=4, "z"=4, "t"=8]>
+sdy.mesh @mesh_iota = <["x"=2, "y"=2]>
+sdy.mesh @mesh_non_iota = <["x"=2, "y"=2], device_ids=[3, 2, 1, 0]>

 // CHECK-LABEL: func @funcop_result_sharding_does_not_match
 func.func @funcop_result_sharding_does_not_match(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
@@ -11,6 +14,41 @@ func.func @funcop_result_sharding_does_not_match(%arg0: tensor<8x16xf32> {sdy.sh
   return %arg0 : tensor<8x16xf32>
 }

+// CHECK-LABEL: func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result
+func.func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{}, {}]>}) {
+  // CHECK-NOT: sdy.reshard
+  return %arg0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result
+func.func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x16xf32>
+  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
+  return %arg0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result
+func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{}, {"t"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{}, {"t"}]> : tensor<8x16xf32>
+  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
+  return %arg0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results
+func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"t"}, {}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {}]>}, tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
+  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x32xf32>
+  // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{}, {"x"}]> : tensor<32x16xf32>
+  // CHECK: return %[[RESHARD1]], %[[RESHARD2]] : tensor<8x32xf32>, tensor<32x16xf32>
+  return %arg0, %arg1 : tensor<8x32xf32>, tensor<32x16xf32>
+}
+
+// CHECK-LABEL: func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result
+func.func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}, {"y"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}, {"t":(2)2}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {"t":(2)2}]> : tensor<8x16xf32>
+  // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
+  return %arg0 : tensor<8x16xf32>
+}
+
 // CHECK-LABEL: func @funcop_result_sharding_does_not_match_funcop_result_empty
 func.func @funcop_result_sharding_does_not_match_funcop_result_empty(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> tensor<8x16xf32> {
   // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> : tensor<8x16xf32>
@@ -1628,3 +1666,22 @@ func.func @optimization_barrier(%arg0: tensor<210xf32> {sdy.sharding = #sdy.shar
   %2 = stablehlo.negate %1 {sdy.sharding= #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
   return %2 : tensor<210xf32>
 }
+
+// CHECK-LABEL: func @optimization_barrier_different_meshes
+func.func @optimization_barrier_different_meshes(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_xt, [{"x"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}]> : tensor<210xf32>
+  // CHECK-NEXT: stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xt, [{"x"}]>]>} %[[RESHARD]]
+  %1 = stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xt, [{"x"}]>]>} %arg0 : tensor<210xf32>
+  %2 = stablehlo.negate %1 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_xt, [{"x"}]>]>} : tensor<210xf32>
+  return %2 : tensor<210xf32>
+}
+
+// CHECK-LABEL: func @optimization_barrier_meshes_different_device_order
+func.func @optimization_barrier_meshes_different_device_order(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_iota, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_non_iota, [{"x"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_non_iota, [{"x"}]> : tensor<210xf32>
+  // CHECK-NEXT: stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_non_iota, [{"x"}]>]>} %[[RESHARD]]
+  %1 = stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_non_iota, [{"x"}]>]>} %arg0 : tensor<210xf32>
+  %2 = stablehlo.negate %1 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_non_iota, [{"x"}]>]>} : tensor<210xf32>
+  return %2 : tensor<210xf32>
+}
+
