Skip to content

Commit 50f6e89

Browse files
#sdy add bytecode serialization utility
PiperOrigin-RevId: 730837065
1 parent 1be3888 commit 50f6e89

File tree

6 files changed

+299
-0
lines changed

6 files changed

+299
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Lit tests for the SDY dialect.
2+
3+
load("//shardy:lit.bzl", "glob_lit_tests")
4+
5+
package(default_visibility = ["//visibility:public"])
6+
7+
# Runtime inputs for the compatibility lit tests in this package: the
# checked-in bytecode artifact and the tools invoked by the tests' RUN lines
# (sdy_opt, sdy_translate, FileCheck). Consumed via `data = [":test_data"]`
# by the `compatibility_tests` target below.
filegroup(
8+
name = "test_data",
9+
testonly = True,
10+
data = [
11+
"//shardy/dialect/sdy/ir/compatibility_test:compatibility_test.mlir.bc",
12+
"//shardy/tools:sdy_opt",
13+
"//shardy/tools:sdy_translate",
14+
"@llvm-project//llvm:FileCheck",
15+
],
16+
)
17+
18+
# Declares the lit tests for this package using the `glob_lit_tests` macro
# loaded from //shardy:lit.bzl, restricted to the `mlir` file extension and
# driven by MLIR's run_lit.sh.
# NOTE(review): presumably the macro globs every matching file in this
# directory and creates one test per file -- confirm against lit.bzl.
glob_lit_tests(
19+
name = "compatibility_tests",
20+
data = [":test_data"],
21+
driver = "@llvm-project//mlir:run_lit.sh",
22+
test_file_exts = ["mlir"],
23+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Smoke test:
2+
// RUN: sdy_opt %s.bc | FileCheck %s
3+
// RUN: sdy_opt %s.bc | sdy_translate --serialize | sdy_opt | FileCheck %s
4+
// RUN: sdy_opt %s.bc | sdy_translate --serialize --strip-debuginfo | sdy_opt | FileCheck %s
5+
// RUN: sdy_translate --deserialize %s.bc | sdy_opt | FileCheck %s
6+
//
7+
// Backward compatibility test:
8+
// RUN: sdy_translate --serialize %s | sdy_opt > %t.0
9+
// RUN: sdy_opt %s > %t.1
10+
// RUN: diff %t.0 %t.1
11+
//
12+
// Forward compatibility test:
13+
// RUN: sdy_translate %s --serialize --strip-debuginfo > %t.2
14+
// RUN: diff %s.bc %t.2
15+
16+
// CHECK: sdy.mesh @empty_mesh = <[]>
17+
sdy.mesh @empty_mesh = <[]>
18+
19+
// CHECK: sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>
20+
sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>
21+
22+
// CHECK: sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>
23+
sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>
24+
25+
// CHECK: sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
26+
sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
27+
28+
// CHECK: sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>
29+
sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>
30+
31+
// CHECK: sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>
32+
sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>
33+
34+
// CHECK-LABEL: func @sharding_constraint
35+
func.func @sharding_constraint(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
36+
// CHECK-NEXT: sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}>
37+
%0 = sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}> : tensor<16x8xf32>
38+
return %0 : tensor<16x8xf32>
39+
}
40+
41+
// CHECK-LABEL: func @reshard
42+
func.func @reshard(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
43+
// CHECK-NEXT: sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}>
44+
%0 = sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}> : tensor<16x8xf32>
45+
return %0 : tensor<16x8xf32>
46+
}
47+
48+
// CHECK-LABEL: func @manual_computation
49+
func.func @manual_computation(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
50+
// CHECK{LITERAL}: sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
51+
// CHECK-NEXT: sdy.return %arg1 : tensor<8x32xf32>
52+
// CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32>
53+
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
54+
sdy.return %arg1 : tensor<8x32xf32>
55+
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
56+
func.return %0: tensor<16x32xf32>
57+
}
58+
59+
// CHECK-LABEL: func @sharding_group
60+
func.func @sharding_group(%arg0: tensor<8xf32>) {
61+
// CHECK-NEXT: sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
62+
sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
63+
func.return
64+
}
65+
66+
// CHECK-LABEL: func @constant
67+
func.func @constant() {
68+
// CHECK-NEXT: sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
69+
%0 = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
70+
func.return
71+
}
72+
73+
// CHECK-LABEL: func @data_flow_edge
74+
func.func @data_flow_edge(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>)
75+
-> (tensor<32x96xf32>, tensor<32x96xf32>) {
76+
// CHECK-NEXT: sdy.data_flow_edge %arg0
77+
// CHECK-NEXT: sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]>
78+
%1 = sdy.data_flow_edge %arg0 : tensor<32x96xf32>
79+
%2 = sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]> : tensor<32x96xf32>
80+
return %1, %2 : tensor<32x96xf32>, tensor<32x96xf32>
81+
}
82+
83+
// CHECK-LABEL: func @propagation_barrier
84+
func.func @propagation_barrier(%arg0 : tensor<8xf32>, %arg1: tensor<16x8xf32>, %arg2: tensor<8x16xf32>)
85+
-> (tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>) {
86+
// CHECK-NEXT: sdy.propagation_barrier %arg0 allowed_direction=NONE
87+
// CHECK-NEXT: sdy.propagation_barrier %arg1 allowed_direction=FORWARD
88+
// CHECK-NEXT: sdy.propagation_barrier %arg2 allowed_direction=BACKWARD
89+
%0 = sdy.propagation_barrier %arg0 allowed_direction=NONE : tensor<8xf32>
90+
%1 = sdy.propagation_barrier %arg1 allowed_direction=FORWARD : tensor<16x8xf32>
91+
%2 = sdy.propagation_barrier %arg2 allowed_direction=BACKWARD : tensor<8x16xf32>
92+
return %0, %1, %2 : tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>
93+
}
94+
95+
// CHECK-LABEL: func @named_computation
96+
func.func @named_computation(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
97+
// CHECK-NEXT: %0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
98+
// CHECK-NEXT: sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
99+
// CHECK-NEXT: } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
100+
%0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
101+
sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
102+
} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
103+
return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
104+
}
105+
106+
// CHECK-LABEL: func @tensor_sharding
107+
func.func @tensor_sharding(%arg0 : tensor<8x8xf32>, %arg1 : tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>) {
108+
// CHECK-NEXT: stablehlo.custom_call @bar(%arg0, %arg1)
109+
// CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>
110+
%0:2 = stablehlo.custom_call @bar(%arg0, %arg1)
111+
{sdy.sharding = #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>}
112+
: (tensor<8x8xf32>, tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>)
113+
return %0#0, %0#1 : tensor<64xf32>, tensor<8x8xf32>
114+
}
115+
116+
// CHECK-LABEL: func @tensor_sharding_on_parameter_result
117+
// CHECK-SAME{LITERAL}: (%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
118+
func.func @tensor_sharding_on_parameter_result(%arg0 : tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>})
119+
-> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
120+
%0 = stablehlo.custom_call @foo(%arg0) : (tensor<8x8xf32>) -> (tensor<64xf32>)
121+
return %0 : tensor<64xf32>
122+
}
123+
124+
// CHECK-LABEL: func @tensor_sharding_scalar
125+
// CHECK-SAME{LITERAL}: (%arg0: tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
126+
func.func @tensor_sharding_scalar(%arg0 : tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>})
127+
-> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
128+
%0 = stablehlo.custom_call @foo(%arg0) : (tensor<f32>) -> (tensor<64xf32>)
129+
return %0 : tensor<64xf32>
130+
}
131+
132+
// CHECK-LABEL: func @tensor_sharding_dynamic_shape
133+
func.func @tensor_sharding_dynamic_shape(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>) {
134+
// CHECK-NEXT: stablehlo.custom_call @bar(%arg0)
135+
// CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>
136+
%0 = stablehlo.custom_call @bar(%arg0)
137+
{sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>}
138+
: (tensor<?x?xf32>) -> (tensor<?x?xf32>)
139+
return %0 : tensor<?x?xf32>
140+
}
141+
142+
// CHECK-LABEL: func @sharding_rule_scalar
143+
func.func @sharding_rule_scalar(%arg0: tensor<f32>) -> tensor<f32> {
144+
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>}
145+
%0 = stablehlo.custom_call @foo(%arg0, %arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>} :
146+
(tensor<f32>, tensor<f32>) -> tensor<f32>
147+
return %0 : tensor<f32>
148+
}
149+
150+
// CHECK-LABEL: func @sharding_rule_tensor
151+
func.func @sharding_rule_tensor(%arg0: tensor<2x4xf32>) -> tensor<8xf32> {
152+
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>}
153+
%0 = stablehlo.reshape %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>} : (tensor<2x4xf32>) -> tensor<8xf32>
154+
return %0 : tensor<8xf32>
155+
}
156+
157+
// CHECK-LABEL: func @sharding_rule_tensor_with_many_dimensions
158+
func.func @sharding_rule_tensor_with_many_dimensions(%arg0: tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32> {
159+
// CHECK: #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])
160+
// CHECK-SAME: ->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10])
161+
// CHECK-SAME: {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>} :
162+
%0 = stablehlo.custom_call @foo(%arg0)
163+
{sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10]) {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>}
164+
: (tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
165+
return %0 : tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
166+
}
167+
168+
// CHECK-LABEL: func @custom_sharding_rule_custom_call
169+
func.func @custom_sharding_rule_custom_call(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
170+
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>}
171+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>} : (tensor<16x32xf32>) -> tensor<16x32xf32>
172+
func.return %0: tensor<16x32xf32>
173+
}
Binary file not shown.

shardy/lit.cfg.py

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
tools = [
3737
'FileCheck',
3838
'sdy_opt',
39+
'sdy_translate',
3940
]
4041
tool_dirs = [
4142
config.llvm_tools_dir,

shardy/tools/BUILD

+21
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,24 @@ cc_binary(
1515
"@llvm-project//mlir:QuantOps",
1616
],
1717
)
18+
19+
# The `sdy_translate` command-line tool, built from sdy_translate_main.cc.
# NOTE(review): sdy_translate_main.cc constructs an mlir::PassManager, but no
# MLIR Pass library target is listed here directly; confirm it is intentionally
# picked up transitively (e.g. via MlirOptLib / Transforms).
cc_binary(
20+
name = "sdy_translate",
21+
srcs = ["sdy_translate_main.cc"],
22+
deps = [
23+
"//shardy/dialect/sdy/ir:dialect",
24+
"//shardy/dialect/sdy/ir:register",
25+
"//shardy/dialect/sdy/transforms:passes",
26+
"@llvm-project//llvm:Support",
27+
"@llvm-project//mlir:AllPassesAndDialects",
28+
"@llvm-project//mlir:BytecodeWriter",
29+
"@llvm-project//mlir:FuncExtensions",
30+
"@llvm-project//mlir:IR",
31+
"@llvm-project//mlir:MlirOptLib",
32+
"@llvm-project//mlir:Parser",
33+
"@llvm-project//mlir:QuantOps",
34+
"@llvm-project//mlir:Support",
35+
"@llvm-project//mlir:Transforms",
36+
"@llvm-project//mlir:TranslateLib",
37+
],
38+
)

shardy/tools/sdy_translate_main.cc

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/* Copyright 2025 The Shardy Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// MLIR `translate` tool for allowing SDY dialect bytecode emission.
17+
//
18+
// Usage:
19+
// sdy_translate <file.mlir> -serialize
20+
// sdy_translate <file.mlir.bc> -deserialize
21+
22+
#include "llvm/Support/CommandLine.h"
23+
#include "llvm/Support/LogicalResult.h"
24+
#include "mlir/Bytecode/BytecodeWriter.h"
25+
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
26+
#include "mlir/Dialect/Quant/IR/Quant.h"
27+
#include "mlir/IR/BuiltinOps.h"
28+
#include "mlir/IR/DialectRegistry.h"
29+
#include "mlir/IR/MLIRContext.h"
30+
#include "mlir/Parser/Parser.h"
31+
#include "mlir/Support/LLVM.h"
32+
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
33+
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
34+
#include "mlir/Tools/mlir-translate/Translation.h"
35+
#include "mlir/Transforms/Passes.h"
36+
#include "shardy/dialect/sdy/ir/dialect.h"
37+
#include "shardy/dialect/sdy/ir/register.h"
38+
39+
namespace mlir {
40+
41+
namespace {
42+
llvm::cl::opt<bool> stripDebuginfoOption(
43+
"strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"),
44+
llvm::cl::init(false));
45+
46+
void registerDialectsForSdy(DialectRegistry &registry) {
47+
mlir::sdy::registerAllDialects(registry);
48+
registry.insert<mlir::quant::QuantDialect>();
49+
}
50+
51+
TranslateFromMLIRRegistration serializeRegistration(
52+
"serialize", "Serialize SDY program into a portable artifact",
53+
[](mlir::ModuleOp module, llvm::raw_ostream &os) -> llvm::LogicalResult {
54+
if (stripDebuginfoOption) {
55+
PassManager pm(module->getContext());
56+
pm.addPass(createStripDebugInfoPass());
57+
if (failed(pm.run(module)))
58+
return module.emitError("failed to strip debuginfo");
59+
}
60+
const auto *producer = "SDY";
61+
BytecodeWriterConfig writerConfig(producer);
62+
return writeBytecodeToFile(module, os, writerConfig);
63+
},
64+
[](DialectRegistry &registry) { registerDialectsForSdy(registry); });
65+
66+
TranslateToMLIRRegistration deserializeRegistration(
67+
"deserialize", "Deserialize a portable artifact into a SDY program",
68+
[](llvm::StringRef input, mlir::MLIRContext *context) {
69+
context->loadDialect<sdy::SdyDialect>();
70+
auto module = parseSourceString<ModuleOp>(input, context);
71+
return module;
72+
},
73+
[](DialectRegistry &registry) { registerDialectsForSdy(registry); });
74+
} // namespace
75+
76+
} // namespace mlir
77+
78+
int main(int argc, char **argv) {
79+
return mlir::asMainReturnCode(
80+
mlir::mlirTranslateMain(argc, argv, "SDY transformation driver\n"));
81+
}

0 commit comments

Comments
 (0)