|
2 | 2 |
|
3 | 3 | // CHECK-LABEL: module @multiple_func_result_shardings
|
4 | 4 | module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes =
|
5 |
| - "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} { |
6 |
| - // CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]> |
| 5 | + "{mesh = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=4]>, mesh2 = #sdy.mesh<[\\\22x\\\22=1, \\\22y\\\22=4, \\\22z\\\22=4]>}"}} { |
| 6 | + // CHECK: sdy.mesh @mesh = <["a"=1, "b"=4, "c"=4]> |
7 | 7 |
|
8 |
| - // CHECK: sdy.mesh @mesh2 = <["a"=1, "b"=4, "c"=1]> |
| 8 | + // CHECK: sdy.mesh @mesh2 = <["x"=1, "y"=4, "z"=4]> |
9 | 9 |
|
10 | 10 | // CHECK-LABEL: func @func_results_with_sharding
|
11 | 11 | // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
|
12 |
| - // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, |
| 12 | + // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, |
13 | 13 | // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}
|
14 | 14 | // CHECK-SAME: ) -> (
|
15 |
| - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}, |
| 15 | + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>}, |
16 | 16 | // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
|
17 |
| - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, |
| 17 | + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p1]>}, |
18 | 18 | // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>},
|
19 | 19 | // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
|
20 |
| - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { |
| 20 | + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p3]>}) { |
21 | 21 | // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2
|
22 | 22 | // CHECK-NEXT: }
|
23 | 23 | func.func @func_results_with_sharding(
|
24 | 24 | %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p2]>"}},
|
25 |
| - %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}}, |
| 25 | + %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}]>"}}, |
26 | 26 | %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}}
|
27 | 27 | ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) {
|
28 |
| - %0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 28 | + %0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
29 | 29 | %1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
|
30 |
| - %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 30 | + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
31 | 31 | %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
|
32 |
| - %4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 32 | + %4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
33 | 33 | return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>
|
34 | 34 | }
|
35 | 35 |
|
@@ -83,19 +83,19 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x
|
83 | 83 | }
|
84 | 84 |
|
85 | 85 | // CHECK-LABEL: func @discard_shardings_on_unknown_ops(
|
86 |
| - // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}) |
87 |
| - // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p4]>}) { |
| 86 | + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>}) |
| 87 | + // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p4]>}) { |
88 | 88 | func.func @discard_shardings_on_unknown_ops(
|
89 |
| - %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}} |
| 89 | + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p0]>"}} |
90 | 90 | ) -> tensor<32xi32> {
|
91 | 91 | // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32>
|
92 |
| - // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32> |
| 92 | + // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"c"}p2]> : tensor<32xi32> |
93 | 93 | // CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32>
|
94 | 94 | // CHECK-NEXT: return %[[UNKNOWN]]
|
95 |
| - %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> |
96 |
| - %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
97 |
| - %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
98 |
| - %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 95 | + %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : tensor<32xi32> |
| 96 | + %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 97 | + %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
| 98 | + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> |
99 | 99 | return %3 : tensor<32xi32>
|
100 | 100 | }
|
101 | 101 |
|
|
0 commit comments