Skip to content

Commit 15dabf5

Browse files
bartchr808 authored and copybara-github committed
#sdy add temporary pass in XLA SDY import to not allow meshes with different axis shapes.
Longer term we should either error during propagation or just support propagating through different meshes, but until https://github.com/jax-ml/jax/issues/26914 is done we can't look at doing this, as JAX first wants to make sure the lowered module never even has different meshes (this is just a restriction JAX is imposing, not one we have to impose in Shardy). PiperOrigin-RevId: 735710081
1 parent e670205 commit 15dabf5

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

shardy/round_trip_import/import_shardy_attrs.cc

+30
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <cassert>
1919
#include <cstdint>
20+
#include <iterator>
2021
#include <memory>
2122
#include <optional>
2223

@@ -143,6 +144,26 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
143144
});
144145
}
145146

147+
// Checks that `meshAttr` has the same axis sizes as every previously seen
// mesh in the module.
//
// `meshesWithAxisSizes` accumulates the reference axis sizes: it starts empty
// and is set from the first non-empty mesh encountered; subsequent meshes are
// compared against it. Returns failure (after emitting an error on
// `moduleOp`) when a mesh's axis sizes differ from the reference, success
// otherwise.
//
// NOTE(review): only axis *sizes* are compared, not axis names — presumably
// intentional per the "different axis shapes" restriction; confirm.
LogicalResult hasDifferentMeshShapes(
    SmallVector<int64_t>& meshesWithAxisSizes, MeshAttr meshAttr,
    ModuleOp moduleOp) {
  ArrayRef<MeshAxisAttr> axes = meshAttr.getAxes();
  // Meshes with no axes are skipped entirely: they neither establish the
  // reference sizes nor conflict with them.
  if (axes.empty()) {
    return mlir::success();
  }
  SmallVector<int64_t> sizes;
  sizes.reserve(axes.size());
  llvm::transform(axes, std::back_inserter(sizes),
                  [](MeshAxisAttr axis) { return axis.getSize(); });
  if (meshesWithAxisSizes.empty()) {
    // First non-empty mesh: record its axis sizes as the reference.
    meshesWithAxisSizes = sizes;
  } else if (meshesWithAxisSizes != sizes) {
    return moduleOp.emitError(
        "JAX does not support multiple meshes with different "
        "axis sizes.");
  }
  return mlir::success();
}
166+
146167
class SdyRoundTripImportShardyAttrsPass
147168
: public PassWrapper<SdyRoundTripImportShardyAttrsPass,
148169
OperationPass<ModuleOp>> {
@@ -167,8 +188,17 @@ class SdyRoundTripImportShardyAttrsPass
167188
// Insert the meshes before any functions.
168189
rewriter.setInsertionPointToStart(moduleOp.getBody());
169190
SymbolTable symbolTable(moduleOp);
191+
// TODO(b/402371282): allow different meshes with different axis sizes
192+
// during import. Either support propagation through different mesh shapes
193+
// or error during propagation.
194+
SmallVector<int64_t> meshesWithAxisSizes;
170195
for (NamedAttribute mesh : sdyMeshes) {
171196
auto meshAttr = cast<MeshAttr>(mesh.getValue());
197+
if (hasDifferentMeshShapes(meshesWithAxisSizes, meshAttr, moduleOp)
198+
.failed()) {
199+
signalPassFailure();
200+
return;
201+
}
172202
symbolTable.insert(
173203
rewriter.create<MeshOp>(moduleOp.getLoc(), mesh.getName(), meshAttr));
174204
}

shardy/round_trip_import/test/sdy_round_trip_import_pipeline.mlir

+19-19
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,34 @@
22

33
// CHECK-LABEL: module @multiple_func_result_shardings
44
module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes =
5-
"{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} {
6-
// CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]>
5+
"{mesh = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=4]>, mesh2 = #sdy.mesh<[\\\22x\\\22=1, \\\22y\\\22=4, \\\22z\\\22=4]>}"}} {
6+
// CHECK: sdy.mesh @mesh = <["a"=1, "b"=4, "c"=4]>
77

8-
// CHECK: sdy.mesh @mesh2 = <["a"=1, "b"=4, "c"=1]>
8+
// CHECK: sdy.mesh @mesh2 = <["x"=1, "y"=4, "z"=4]>
99

1010
// CHECK-LABEL: func @func_results_with_sharding
1111
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
12-
// CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>},
12+
// CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>},
1313
// CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}
1414
// CHECK-SAME: ) -> (
15-
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>},
15+
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>},
1616
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
17-
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>},
17+
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p1]>},
1818
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>},
1919
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
20-
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) {
20+
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p3]>}) {
2121
// CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2
2222
// CHECK-NEXT: }
2323
func.func @func_results_with_sharding(
2424
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p2]>"}},
25-
%arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}},
25+
%arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}]>"}},
2626
%arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}}
2727
) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) {
28-
%0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
28+
%0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
2929
%1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
30-
%2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
30+
%2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
3131
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
32-
%4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
32+
%4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
3333
return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>
3434
}
3535

@@ -83,19 +83,19 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x
8383
}
8484

8585
// CHECK-LABEL: func @discard_shardings_on_unknown_ops(
86-
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>})
87-
// CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p4]>}) {
86+
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>})
87+
// CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p4]>}) {
8888
func.func @discard_shardings_on_unknown_ops(
89-
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}}
89+
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p0]>"}}
9090
) -> tensor<32xi32> {
9191
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32>
92-
// CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32>
92+
// CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"c"}p2]> : tensor<32xi32>
9393
// CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32>
9494
// CHECK-NEXT: return %[[UNKNOWN]]
95-
%0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32>
96-
%1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
97-
%2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
98-
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
95+
%0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : tensor<32xi32>
96+
%1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
97+
%2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
98+
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
9999
return %3 : tensor<32xi32>
100100
}
101101

0 commit comments

Comments
 (0)