Skip to content

Commit bbe504a

Browse files
ZixuanJiangcopybara-github
authored andcommitted
Add strictPrefixOf method to AxisRefAttr.
"a.strictPrefixOf(b)" is equivalent to "a.prefixOf(b) && a != b" This method will be used in propagation. PiperOrigin-RevId: 656553357
1 parent 8f92b38 commit bbe504a

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

shardy/dialect/sdy/ir/attrs.td

+14
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,26 @@ def Sdy_AxisRef : AttrDef<Sdy_Dialect, "AxisRef"> {
181181
// For example:
182182
// "a":(1)2, "a" -> true
183183
// "a":(2)2, "a":(2)4 -> true
184+
// "a", "a" -> true
184185
// "a":(2)4, "a":(2)4 -> true
185186
// "a":(1)4, "a":(1)2 -> false
186187
// "a":(1)4, "a":(2)8 -> false
187188
// "a":(1)2, "b" -> false
188189
bool prefixOf(AxisRefAttr other) const;
189190

191+
// Returns whether this axis or sub-axis is a strict prefix of `other`.
192+
// "a.strictPrefixOf(b)" is equivalent to "a.prefixOf(b) && a != b".
193+
//
194+
// For example:
195+
// "a":(1)2, "a" -> true
196+
// "a":(2)2, "a":(2)4 -> true
197+
// "a", "a" -> false
198+
// "a":(2)4, "a":(2)4 -> false
199+
// "a":(1)4, "a":(1)2 -> false
200+
// "a":(1)4, "a":(2)8 -> false
201+
// "a":(1)2, "b" -> false
202+
bool strictPrefixOf(AxisRefAttr other) const;
203+
190204
// Returns whether this axis or sub-axis overlaps with `other`, i.e., they
191205
// are equal or there is a sub-axis that is contained in both axis refs.
192206
//

shardy/dialect/sdy/ir/dialect.cc

+4
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ bool AxisRefAttr::prefixOf(AxisRefAttr other) const {
183183
getSubAxisPreSize() == other.getSubAxisPreSize();
184184
}
185185

186+
bool AxisRefAttr::strictPrefixOf(AxisRefAttr other) const {
187+
return prefixOf(other) && *this != other;
188+
}
189+
186190
bool AxisRefAttr::overlaps(AxisRefAttr other) const {
187191
if (other.getName() != getName()) {
188192
return false;

shardy/dialect/sdy/ir/dialect_test.cc

+30-12
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,36 @@ TEST_F(DialectTest, AxisRefAttrContains) {
102102
}
103103

104104
TEST_F(DialectTest, AxisRefAttrPrefixOf) {
105-
EXPECT_TRUE(createAxis("x").prefixOf(createAxis("x")));
106-
EXPECT_TRUE(createSubAxis("x", 1, 4).prefixOf(createAxis("x")));
107-
EXPECT_TRUE(createSubAxis("x", 1, 2).prefixOf(createSubAxis("x", 1, 4)));
108-
EXPECT_TRUE(createSubAxis("x", 2, 4).prefixOf(createSubAxis("x", 2, 4)));
109-
EXPECT_TRUE(createSubAxis("x", 2, 2).prefixOf(createSubAxis("x", 2, 8)));
110-
111-
EXPECT_FALSE(createAxis("x").prefixOf(createAxis("y")));
112-
EXPECT_FALSE(createSubAxis("x", 1, 2).prefixOf(createAxis("y")));
113-
EXPECT_FALSE(createAxis("x").prefixOf(createSubAxis("x", 1, 2)));
114-
EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 1, 2)));
115-
EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 4, 2)));
116-
EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 2, 4)));
105+
auto strictPrefixOf = [](AxisRefAttr a, AxisRefAttr b) {
106+
EXPECT_TRUE(a.strictPrefixOf(b));
107+
EXPECT_TRUE(a.prefixOf(b));
108+
EXPECT_FALSE(b.prefixOf(a));
109+
EXPECT_FALSE(b.strictPrefixOf(a));
110+
};
111+
strictPrefixOf(createSubAxis("x", 1, 4), createAxis("x"));
112+
strictPrefixOf(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4));
113+
strictPrefixOf(createSubAxis("x", 2, 2), createSubAxis("x", 2, 8));
114+
115+
auto equals = [](AxisRefAttr a, AxisRefAttr b) {
116+
EXPECT_TRUE(a == b);
117+
EXPECT_TRUE(a.prefixOf(b));
118+
EXPECT_TRUE(b.prefixOf(a));
119+
EXPECT_FALSE(a.strictPrefixOf(b));
120+
EXPECT_FALSE(b.strictPrefixOf(a));
121+
};
122+
equals(createAxis("x"), createAxis("x"));
123+
equals(createSubAxis("x", 2, 4), createSubAxis("x", 2, 4));
124+
125+
auto isNotPrefix = [](AxisRefAttr a, AxisRefAttr b) {
126+
EXPECT_FALSE(a.prefixOf(b));
127+
EXPECT_FALSE(b.prefixOf(a));
128+
EXPECT_FALSE(a.strictPrefixOf(b));
129+
EXPECT_FALSE(b.strictPrefixOf(a));
130+
};
131+
isNotPrefix(createAxis("x"), createAxis("y"));
132+
isNotPrefix(createSubAxis("x", 1, 2), createAxis("y"));
133+
isNotPrefix(createSubAxis("x", 1, 4), createSubAxis("x", 4, 2));
134+
isNotPrefix(createSubAxis("x", 1, 4), createSubAxis("x", 2, 4));
117135
}
118136

119137
TEST_F(DialectTest, AxisRefAttrOverlaps) {

0 commit comments

Comments
 (0)