// Smoke test:
// RUN: sdy_opt %s.bc | FileCheck %s
// RUN: sdy_opt %s.bc | sdy_translate --serialize | sdy_opt | FileCheck %s
// RUN: sdy_opt %s.bc | sdy_translate --serialize --strip-debuginfo | sdy_opt | FileCheck %s
// RUN: sdy_translate --deserialize %s.bc | sdy_opt | FileCheck %s
//
// Backward compatibility test:
// RUN: sdy_translate --serialize %s | sdy_opt > %t.0
// RUN: sdy_opt %s > %t.1
// RUN: diff %t.0 %t.1
//
// Forward compatibility test:
// RUN: sdy_translate %s --serialize -strip-debuginfo > %t.2
// RUN: diff %s.bc %t.2

// Mesh declarations referenced by the shardings in the test cases below.

// CHECK: sdy.mesh @empty_mesh = <[]>
sdy.mesh @empty_mesh = <[]>

// CHECK: sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>
sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>

// CHECK: sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>
sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>

// CHECK: sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
sdy.mesh @mesh_xy = <["x"=2, "y"=4]>

// CHECK: sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>
sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>

// CHECK: sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>
sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>

// Round-trips sdy.sharding_constraint with a replicated axis.
// CHECK-LABEL: func @sharding_constraint
func.func @sharding_constraint(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
  // CHECK-NEXT: sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}>
  %0 = sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}> : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}

// Round-trips sdy.reshard with a replicated axis.
// CHECK-LABEL: func @reshard
func.func @reshard(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
  // CHECK-NEXT: sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}>
  %0 = sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}> : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}

// Round-trips sdy.manual_computation, including its region and manual axes.
// CHECK-LABEL: func @manual_computation
func.func @manual_computation(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
  // CHECK{LITERAL}: sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
  // CHECK-NEXT: sdy.return %arg1 : tensor<8x32xf32>
  // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32>
  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
    sdy.return %arg1 : tensor<8x32xf32>
  } : (tensor<16x32xf32>) -> tensor<16x32xf32>
  func.return %0: tensor<16x32xf32>
}

// Round-trips sdy.sharding_group.
// NOTE(review): the original CHECK line lacked the trailing ':' after
// "CHECK", so FileCheck silently ignored it, and it also expected a
// `type=AS` token that the op below does not carry. Made it an active
// CHECK-NEXT matching the op as written.
// CHECK-LABEL: func @sharding_group
func.func @sharding_group(%arg0: tensor<8xf32>) {
  // CHECK-NEXT: sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
  sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
  func.return
}

// Round-trips sdy.constant with a dense splat attribute.
// CHECK-LABEL: func @constant
func.func @constant() {
  // CHECK-NEXT: sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
  %0 = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
  func.return
}

// Round-trips sdy.data_flow_edge both with and without an optional sharding.
// CHECK-LABEL: func @data_flow_edge
func.func @data_flow_edge(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>)
    -> (tensor<32x96xf32>, tensor<32x96xf32>) {
  // CHECK-NEXT: sdy.data_flow_edge %arg0
  // CHECK-NEXT: sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]>
  %1 = sdy.data_flow_edge %arg0 : tensor<32x96xf32>
  %2 = sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]> : tensor<32x96xf32>
  return %1, %2 : tensor<32x96xf32>, tensor<32x96xf32>
}

// Round-trips sdy.propagation_barrier for all three allowed directions.
// CHECK-LABEL: func @propagation_barrier
func.func @propagation_barrier(%arg0 : tensor<8xf32>, %arg1: tensor<16x8xf32>, %arg2: tensor<8x16xf32>)
    -> (tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>) {
  // CHECK-NEXT: sdy.propagation_barrier %arg0 allowed_direction=NONE
  // CHECK-NEXT: sdy.propagation_barrier %arg1 allowed_direction=FORWARD
  // CHECK-NEXT: sdy.propagation_barrier %arg2 allowed_direction=BACKWARD
  %0 = sdy.propagation_barrier %arg0 allowed_direction=NONE : tensor<8xf32>
  %1 = sdy.propagation_barrier %arg1 allowed_direction=FORWARD : tensor<16x8xf32>
  %2 = sdy.propagation_barrier %arg2 allowed_direction=BACKWARD : tensor<8x16xf32>
  return %0, %1, %2 : tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>
}

// Round-trips sdy.named_computation with multiple operands/results.
// CHECK-LABEL: func @named_computation
func.func @named_computation(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
  // CHECK-NEXT: %0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
  // CHECK-NEXT: sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
  // CHECK-NEXT: } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
    sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
  } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}

// Round-trips the sdy.sharding attribute on an op, including priorities (pN)
// and sub-axis notation ("y":(1)2).
// CHECK-LABEL: func @tensor_sharding
func.func @tensor_sharding(%arg0 : tensor<8x8xf32>, %arg1 : tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>) {
  // CHECK-NEXT: stablehlo.custom_call @bar(%arg0, %arg1)
  // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>
  %0:2 = stablehlo.custom_call @bar(%arg0, %arg1)
      {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>}
      : (tensor<8x8xf32>, tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>)
  return %0#0, %0#1 : tensor<64xf32>, tensor<8x8xf32>
}

// Round-trips sdy.sharding attributes attached to function arguments and
// results.
// CHECK-LABEL: func @tensor_sharding_on_parameter_result
// CHECK-SAME{LITERAL}: (%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
func.func @tensor_sharding_on_parameter_result(%arg0 : tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>})
    -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
  %0 = stablehlo.custom_call @foo(%arg0) : (tensor<8x8xf32>) -> (tensor<64xf32>)
  return %0 : tensor<64xf32>
}

// Round-trips an empty (scalar, rank-0) sharding on a function argument.
// CHECK-LABEL: func @tensor_sharding_scalar
// CHECK-SAME{LITERAL}: (%arg0: tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
func.func @tensor_sharding_scalar(%arg0 : tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>})
    -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
  %0 = stablehlo.custom_call @foo(%arg0) : (tensor<f32>) -> (tensor<64xf32>)
  return %0 : tensor<64xf32>
}

// Round-trips a sharding on a dynamically-shaped tensor.
// CHECK-LABEL: func @tensor_sharding_dynamic_shape
func.func @tensor_sharding_dynamic_shape(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>) {
  // CHECK-NEXT: stablehlo.custom_call @bar(%arg0)
  // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>
  %0 = stablehlo.custom_call @bar(%arg0)
      {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>}
      : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
  return %0 : tensor<?x?xf32>
}

// Round-trips a custom sdy.op_sharding_rule on scalar operands/results.
// CHECK-LABEL: func @sharding_rule_scalar
func.func @sharding_rule_scalar(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>}
  %0 = stablehlo.custom_call @foo(%arg0, %arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>} :
      (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Round-trips an sdy.op_sharding_rule with a compound factor (ij) on a
// reshape.
// CHECK-LABEL: func @sharding_rule_tensor
func.func @sharding_rule_tensor(%arg0: tensor<2x4xf32>) -> tensor<8xf32> {
  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>}
  %0 = stablehlo.reshape %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>} : (tensor<2x4xf32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

// Round-trips an sdy.op_sharding_rule whose factor count exceeds the
// alphabet, exercising the z_N overflow factor names.
// CHECK-LABEL: func @sharding_rule_tensor_with_many_dimensions
func.func @sharding_rule_tensor_with_many_dimensions(%arg0: tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32> {
  // CHECK: #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])
  // CHECK-SAME: ->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10])
  // CHECK-SAME: {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>} :
  %0 = stablehlo.custom_call @foo(%arg0)
      {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10]) {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>}
      : (tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
  return %0 : tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
}

// Round-trips the `custom` marker on an sdy.op_sharding_rule for a custom
// call.
// CHECK-LABEL: func @custom_sharding_rule_custom_call
func.func @custom_sharding_rule_custom_call(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>}
  %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>} : (tensor<16x32xf32>) -> tensor<16x32xf32>
  func.return %0: tensor<16x32xf32>
}