1
1
// RUN: sdy_opt %s -sdy-insert-explicit-reshards | FileCheck %s
2
2
3
3
sdy.mesh @mesh = <[" x" =4 , " y" =2 ]>
4
+ sdy.mesh @mesh_xt = <[" x" =2 , " t" =4 ]>
4
5
sdy.mesh @mesh_xyz = <[" x" =4 , " y" =2 , " z" =4 ]>
5
6
sdy.mesh @mesh_xyzt = <[" x" =4 , " y" =4 , " z" =4 , " t" =8 ]>
7
+ sdy.mesh @mesh_iota = <[" x" =2 , " y" =2 ]>
8
+ sdy.mesh @mesh_non_iota = <[" x" =2 , " y" =2 ], device_ids =[3 , 2 , 1 , 0 ]>
6
9
7
10
// CHECK-LABEL: func @funcop_result_sharding_does_not_match
8
11
func.func @funcop_result_sharding_does_not_match (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}]>}) -> (tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {" x" }]>}) {
@@ -11,6 +14,41 @@ func.func @funcop_result_sharding_does_not_match(%arg0: tensor<8x16xf32> {sdy.sh
11
14
return %arg0 : tensor <8 x16 xf32 >
12
15
}
13
16
17
+ // CHECK-LABEL: func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result
18
+ func.func @funcop_result_unsharded_but_different_meshes_between_return_and_func_result (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {}]>}) -> (tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{}, {}]>}) {
19
+ // CHECK-NOT: sdy.reshard
20
+ return %arg0 : tensor <8 x16 xf32 >
21
+ }
22
+
23
+ // CHECK-LABEL: func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result
24
+ func.func @funcop_result_sharding_matches_but_different_meshes_between_return_and_func_result (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}]>}) -> (tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{" x" }, {}]>}) {
25
+ // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x16xf32>
26
+ // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
27
+ return %arg0 : tensor <8 x16 xf32 >
28
+ }
29
+
30
+ // CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result
31
+ func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}]>}) -> (tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{}, {" t" }]>}) {
32
+ // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{}, {"t"}]> : tensor<8x16xf32>
33
+ // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
34
+ return %arg0 : tensor <8 x16 xf32 >
35
+ }
36
+
37
+ // CHECK-LABEL: func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results
38
+ func.func @funcop_result_sharding_does_not_match_different_meshes_between_return_and_func_result_multiple_results (%arg0: tensor <8 x32 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {" y" }]>}, %arg1: tensor <32 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{" t" }, {}]>}) -> (tensor <8 x32 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{" x" }, {}]>}, tensor <32 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {" x" }]>}) {
39
+ // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {}]> : tensor<8x32xf32>
40
+ // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{}, {"x"}]> : tensor<32x16xf32>
41
+ // CHECK: return %[[RESHARD1]], %[[RESHARD2]] : tensor<8x32xf32>, tensor<32x16xf32>
42
+ return %arg0 , %arg1 : tensor <8 x32 xf32 >, tensor <32 x16 xf32 >
43
+ }
44
+
45
+ // CHECK-LABEL: func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result
46
+ func.func @funcop_result_identical_sharding_but_different_meshes_between_return_and_func_result (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" :(1 )2 }, {" y" }]>}) -> (tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{" x" }, {" t" :(2 )2 }]>}) {
47
+ // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}, {"t":(2)2}]> : tensor<8x16xf32>
48
+ // CHECK: return %[[RESHARD]] : tensor<8x16xf32>
49
+ return %arg0 : tensor <8 x16 xf32 >
50
+ }
51
+
14
52
// CHECK-LABEL: func @funcop_result_sharding_does_not_match_funcop_result_empty
15
53
func.func @funcop_result_sharding_does_not_match_funcop_result_empty (%arg0: tensor <8 x16 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}]>}) -> tensor <8 x16 xf32 > {
16
54
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> : tensor<8x16xf32>
@@ -1628,3 +1666,22 @@ func.func @optimization_barrier(%arg0: tensor<210xf32> {sdy.sharding = #sdy.shar
1628
1666
%2 = stablehlo.negate %1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
1629
1667
return %2 : tensor <210 xf32 >
1630
1668
}
1669
+
1670
+ // CHECK-LABEL: func @optimization_barrier_different_meshes
1671
+ func.func @optimization_barrier_different_meshes (%arg0: tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }]>}) -> (tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh_xt , [{" x" }]>}) {
1672
+ // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_xt, [{"x"}]> : tensor<210xf32>
1673
+ // CHECK-NEXT: stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xt, [{"x"}]>]>} %[[RESHARD]]
1674
+ %1 = stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value <[<@mesh_xt , [{" x" }]>]>} %arg0 : tensor <210 xf32 >
1675
+ %2 = stablehlo.negate %1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh_xt , [{" x" }]>]>} : tensor <210 xf32 >
1676
+ return %2 : tensor <210 xf32 >
1677
+ }
1678
+
1679
+ // CHECK-LABEL: func @optimization_barrier_meshes_different_device_order
1680
+ func.func @optimization_barrier_meshes_different_device_order (%arg0: tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh_iota , [{" x" }]>}) -> (tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh_non_iota , [{" x" }]>}) {
1681
+ // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh_non_iota, [{"x"}]> : tensor<210xf32>
1682
+ // CHECK-NEXT: stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value<[<@mesh_non_iota, [{"x"}]>]>} %[[RESHARD]]
1683
+ %1 = stablehlo.optimization_barrier {sdy.sharding = #sdy.sharding_per_value <[<@mesh_non_iota , [{" x" }]>]>} %arg0 : tensor <210 xf32 >
1684
+ %2 = stablehlo.negate %1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh_non_iota , [{" x" }]>]>} : tensor <210 xf32 >
1685
+ return %2 : tensor <210 xf32 >
1686
+ }
1687
+
0 commit comments