Skip to content

Commit 2d20abb

Browse files
[SDY] Add equi-sharding op ShardingGroupOp parsing and printing.
go/sdy-equi-sharding PiperOrigin-RevId: 651172814
1 parent b6e23b1 commit 2d20abb

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

shardy/dialect/sdy/ir/enums.td

+16
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,20 @@ def Sdy_PropagationDirection :
3838
let cppNamespace = Sdy_Dialect.cppNamespace;
3939
}
4040

41+
// A sharding group type specifies the type of equi-sharding behavior applied
42+
// to a sharding group. This can be either
43+
44+
// * `AS`: all tensors in the group are strictly enforced to have the same
45+
// sharding
46+
// * `LIKE`: tensors in the group are encouraged to share shardings post
47+
// propagation but the sameness of their shardings is not gauranteed.
48+
def Sdy_ShardingGroupType :
49+
I32EnumAttr<"ShardingGroupType",
50+
"sharding group type enum", [
51+
I32EnumAttrCase<"AS", 0>,
52+
I32EnumAttrCase<"LIKE", 1>]> {
53+
let genSpecializedAttr = 1;
54+
let cppNamespace = Sdy_Dialect.cppNamespace;
55+
}
56+
4157
#endif // SDY_ENUMS

shardy/dialect/sdy/ir/ops.td

+28
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,34 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation",
187187
}];
188188
}
189189

190+
def Sdy_ShardingGroupOp : Sdy_Op<"sharding_group",
191+
// The side-effects of this op are idempotent but since the op doesn't return
192+
// the input tensor we don't mark it with the Idempotent op trait.
193+
[NativeOpTrait<"ZeroResults">]>{
194+
let summary = "Sharding group operation";
195+
let description = [{
196+
This op provides an interface to assign tensors to sharding groups
197+
allowing users to have greater control over the similarity of shardings in
198+
such groups. This operation takes arguments like group ID and sharding type
199+
(LIKE or AS) and returns no result, but instead modifies the internal shard
200+
group representation to guide the propagation process to meet the
201+
"shard type" constraint specified.
202+
203+
Reference: go/sdy-equi-shard
204+
}];
205+
206+
let arguments = (ins
207+
AnyRankedTensor:$input,
208+
I64Attr:$group_id,
209+
Sdy_ShardingGroupType:$group_type,
210+
OptionalAttr<Sdy_TensorSharding>:$sharding);
211+
212+
// Dangling op has no results.
213+
let results = (outs);
214+
215+
let assemblyFormat = "$input `group_id````=```$group_id `type````=```$group_type (`sharding````=``` $sharding^)? attr-dict `:` type($input)";
216+
}
217+
190218
def Sdy_ConstantOp : Sdy_Op<"constant",
191219
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
192220
let summary = "Constant operation";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: sdy_opt %s 2>&1 | FileCheck %s
2+
3+
sdy.mesh @mesh = <"a"=3>
4+
5+
// CHECK-LABEL: func @add_to_shard_like
6+
func.func @add_to_shard_like(%arg0: tensor<8xf32>) {
7+
// sdy.sharding_group %arg0 group_id=12 type=LIKE : tensor<8xf32>
8+
sdy.sharding_group %arg0 group_id=12 type=LIKE : tensor<8xf32>
9+
func.return
10+
}
11+
12+
// CHECK-LABEL: func @add_to_shard_as
13+
func.func @add_to_shard_as(%arg0: tensor<8xf32>) {
14+
// CHECK sdy.sharding_group %arg0 group_id=21 type=AS : tensor<8xf32>
15+
sdy.sharding_group %arg0 group_id=21 type=AS : tensor<8xf32>
16+
func.return
17+
}
18+
19+
// CHECK-LABEL: func @add_to_shard_as_with_sharding
20+
func.func @add_to_shard_as_with_sharding(%arg0: tensor<8xf32>) {
21+
// CHECK sdy.sharding_group %arg0 group_id=21 type=AS sharding=<@mesh, [{"a"}]> : tensor<8xf32>
22+
sdy.sharding_group %arg0 group_id=21 type=AS sharding=<@mesh, [{"a"}]> : tensor<8xf32>
23+
func.return
24+
}

0 commit comments

Comments
 (0)