diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td index ce29e4e3..a3226a8b 100644 --- a/shardy/dialect/sdy/ir/attrs.td +++ b/shardy/dialect/sdy/ir/attrs.td @@ -181,12 +181,26 @@ def Sdy_AxisRef : AttrDef { // For example: // "a":(1)2, "a" -> true // "a":(2)2, "a":(2)4 -> true + // "a", "a" -> true // "a":(2)4, "a":(2)4 -> true // "a":(1)4, "a":(1)2 -> false // "a":(1)4, "a":(2)8 -> false // "a":(1)2, "b" -> false bool prefixOf(AxisRefAttr other) const; + // Returns whether this axis or sub-axis is a strict prefix of `other`. + // "a.strictPrefixOf(b)" is equivalent to "a.prefixOf(b) && a != b". + // + // For example: + // "a":(1)2, "a" -> true + // "a":(2)2, "a":(2)4 -> true + // "a", "a" -> false + // "a":(2)4, "a":(2)4 -> false + // "a":(1)4, "a":(1)2 -> false + // "a":(1)4, "a":(2)8 -> false + // "a":(1)2, "b" -> false + bool strictPrefixOf(AxisRefAttr other) const; + // Returns whether this axis or sub-axis overlaps with `other`, i.e., they // are equal or there is a sub-axis that is contained in both axis refs. // diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc index 64482df6..f6f88bc5 100644 --- a/shardy/dialect/sdy/ir/dialect.cc +++ b/shardy/dialect/sdy/ir/dialect.cc @@ -183,6 +183,10 @@ bool AxisRefAttr::prefixOf(AxisRefAttr other) const { getSubAxisPreSize() == other.getSubAxisPreSize(); } +bool AxisRefAttr::strictPrefixOf(AxisRefAttr other) const { + return prefixOf(other) && *this != other; +} + bool AxisRefAttr::overlaps(AxisRefAttr other) const { if (other.getName() != getName()) { return false; diff --git a/shardy/dialect/sdy/ir/dialect_test.cc b/shardy/dialect/sdy/ir/dialect_test.cc index 5cecdf40..6c5c595e 100644 --- a/shardy/dialect/sdy/ir/dialect_test.cc +++ b/shardy/dialect/sdy/ir/dialect_test.cc @@ -102,18 +102,36 @@ TEST_F(DialectTest, AxisRefAttrContains) { } TEST_F(DialectTest, AxisRefAttrPrefixOf) { - EXPECT_TRUE(createAxis("x").prefixOf(createAxis("x"))); - EXPECT_TRUE(createSubAxis("x", 1, 4).prefixOf(createAxis("x"))); - EXPECT_TRUE(createSubAxis("x", 1, 2).prefixOf(createSubAxis("x", 1, 4))); - EXPECT_TRUE(createSubAxis("x", 2, 4).prefixOf(createSubAxis("x", 2, 4))); - EXPECT_TRUE(createSubAxis("x", 2, 2).prefixOf(createSubAxis("x", 2, 8))); - - EXPECT_FALSE(createAxis("x").prefixOf(createAxis("y"))); - EXPECT_FALSE(createSubAxis("x", 1, 2).prefixOf(createAxis("y"))); - EXPECT_FALSE(createAxis("x").prefixOf(createSubAxis("x", 1, 2))); - EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 1, 2))); - EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 4, 2))); - EXPECT_FALSE(createSubAxis("x", 1, 4).prefixOf(createSubAxis("x", 2, 4))); + auto strictPrefixOf = [](AxisRefAttr a, AxisRefAttr b) { + EXPECT_TRUE(a.strictPrefixOf(b)); + EXPECT_TRUE(a.prefixOf(b)); + EXPECT_FALSE(b.prefixOf(a)); + EXPECT_FALSE(b.strictPrefixOf(a)); + }; + strictPrefixOf(createSubAxis("x", 1, 4), createAxis("x")); + strictPrefixOf(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4)); + strictPrefixOf(createSubAxis("x", 2, 2), createSubAxis("x", 2, 8)); + + auto equals = [](AxisRefAttr a, AxisRefAttr b) { + EXPECT_TRUE(a == b); + EXPECT_TRUE(a.prefixOf(b)); + EXPECT_TRUE(b.prefixOf(a)); + EXPECT_FALSE(a.strictPrefixOf(b)); + EXPECT_FALSE(b.strictPrefixOf(a)); + }; + equals(createAxis("x"), createAxis("x")); + equals(createSubAxis("x", 2, 4), createSubAxis("x", 2, 4)); + + auto isNotPrefix = [](AxisRefAttr a, AxisRefAttr b) { + EXPECT_FALSE(a.prefixOf(b)); + EXPECT_FALSE(b.prefixOf(a)); + EXPECT_FALSE(a.strictPrefixOf(b)); + EXPECT_FALSE(b.strictPrefixOf(a)); + }; + isNotPrefix(createAxis("x"), createAxis("y")); + isNotPrefix(createSubAxis("x", 1, 2), createAxis("y")); + isNotPrefix(createSubAxis("x", 1, 4), createSubAxis("x", 4, 2)); + isNotPrefix(createSubAxis("x", 1, 4), createSubAxis("x", 2, 4)); } TEST_F(DialectTest, AxisRefAttrOverlaps) {