Skip to content

Commit a574831

Browse files
committed
[mlir][tblgen] Fix region and successor references in custom directives
Previously, references to regions and successors were incorrectly disallowed outside the top-level assembly form. This change enables the use of bound regions and successors as variables in custom directives.
1 parent 026aae7 commit a574831

File tree

7 files changed

+111
-8
lines changed

7 files changed

+111
-8
lines changed

mlir/test/IR/region.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,10 @@ func.func @named_region_has_wrong_number_of_blocks() {
106106
test.single_no_terminator_custom_asm_op {
107107
"important_dont_drop"() : () -> ()
108108
}
109+
110+
// -----
111+
112+
// CHECK: dummy_op_with_region_ref
113+
test.dummy_op_with_region_ref() ({
114+
^bb0:
115+
}) : () -> ()

mlir/test/lib/Dialect/Test/TestFormatUtils.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,26 @@ void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
381381
Attribute attr) {
382382
printer.printAttributeWithoutType(attr);
383383
}
384+
385+
//===----------------------------------------------------------------------===//
386+
// CustomDirectiveDummyRegionRef
387+
//===----------------------------------------------------------------------===//
388+
389+
ParseResult test::parseDummyRegionRef(OpAsmParser &parser, Region &region) {
390+
return success();
391+
}
392+
393+
void test::printDummyRegionRef(OpAsmPrinter &printer, Operation *op,
394+
Region &region) { /* do nothing */ }
395+
396+
//===----------------------------------------------------------------------===//
397+
// CustomDirectiveDummySuccessorRef
398+
//===----------------------------------------------------------------------===//
399+
400+
ParseResult test::parseDummySuccessorRef(OpAsmParser &parser,
401+
Block *successor) {
402+
return success();
403+
}
404+
405+
void test::printDummySuccessorRef(OpAsmPrinter &printer, Operation *op,
406+
Block *successor) { /* do nothing */ }

mlir/test/lib/Dialect/Test/TestFormatUtils.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,24 @@ mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
207207
void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
208208
mlir::TypeAttr type, mlir::Attribute attr);
209209

210+
//===----------------------------------------------------------------------===//
211+
// CustomDirectiveDummyRegionRef
212+
//===----------------------------------------------------------------------===//
213+
214+
mlir::ParseResult parseDummyRegionRef(mlir::OpAsmParser &parser,
215+
mlir::Region &region);
216+
void printDummyRegionRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
217+
mlir::Region &region);
218+
219+
//===----------------------------------------------------------------------===//
220+
// CustomDirectiveDummySuccessorRef
221+
//===----------------------------------------------------------------------===//
222+
223+
mlir::ParseResult parseDummySuccessorRef(mlir::OpAsmParser &parser,
224+
mlir::Block *successor);
225+
void printDummySuccessorRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
226+
mlir::Block *successor);
227+
210228
} // end namespace test
211229

212230
#endif // MLIR_TESTFORMATUTILS_H

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3665,4 +3665,22 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
36653665
);
36663666
}
36673667

3668+
//===----------------------------------------------------------------------===//
3669+
// Test assembly format references
3670+
//===----------------------------------------------------------------------===//
3671+
3672+
def TestOpWithRegionRef : TEST_Op<"dummy_op_with_region_ref"> {
3673+
let regions = (region AnyRegion:$body);
3674+
let assemblyFormat = [{
3675+
$body attr-dict custom<DummyRegionRef>(ref($body))
3676+
}];
3677+
}
3678+
3679+
def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
3680+
let successors = (successor AnySuccessor:$successor);
3681+
let assemblyFormat = [{
3682+
$successor attr-dict custom<DummySuccessorRef>(ref($successor))
3683+
}];
3684+
}
3685+
36683686
#endif // TEST_OPS

mlir/test/mlir-tblgen/op-format-spec.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def DirectiveCustomValidD : TestFormat_Op<[{
4949
def DirectiveCustomValidE : TestFormat_Op<[{
5050
custom<MyDirective>(prop-dict) attr-dict
5151
}]>, Arguments<(ins UnitAttr:$flag)>;
52+
def DirectiveCustomValidF : TestFormat_Op<[{
53+
$operand custom<MyDirective>(ref($operand)) attr-dict
54+
}]>, Arguments<(ins Optional<I64>:$operand)>;
55+
def DirectiveCustomValidG : TestFormat_Op<[{
56+
$body custom<MyDirective>(ref($body)) attr-dict
57+
}]> {
58+
let regions = (region AnyRegion:$body);
59+
}
60+
def DirectiveCustomValidH : TestFormat_Op<[{
61+
$successor custom<MyDirective>(ref($successor)) attr-dict
62+
}]> {
63+
let successors = (successor AnySuccessor:$successor);
64+
}
5265

5366
//===----------------------------------------------------------------------===//
5467
// functional-type

mlir/test/mlir-tblgen/op-format.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,23 @@ def OptionalGroupC : TestFormat_Op<[{
109109
def OptionalGroupD : TestFormat_Op<[{
110110
(custom<Custom>($a, $b)^)? attr-dict
111111
}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<I64>:$a, Optional<I64>:$b)>;
112+
113+
// CHECK-LABEL: RegionRef::parse
114+
// CHECK: auto odsResult = parseCustom(parser, *bodyRegion);
115+
// CHECK-LABEL: RegionRef::print
116+
// CHECK: printCustom(_odsPrinter, *this, getBody());
117+
def RegionRef : TestFormat_Op<[{
118+
$body custom<Custom>(ref($body)) attr-dict
119+
}]> {
120+
let regions = (region AnyRegion:$body);
121+
}
122+
123+
// CHECK-LABEL: SuccessorRef::parse
124+
// CHECK: auto odsResult = parseCustom(parser, successorSuccessor);
125+
// CHECK-LABEL: SuccessorRef::print
126+
// CHECK: printCustom(_odsPrinter, *this, getSuccessor());
127+
def SuccessorRef : TestFormat_Op<[{
128+
$successor custom<Custom>(ref($successor)) attr-dict
129+
}]> {
130+
let successors = (successor AnySuccessor:$successor);
131+
}

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,11 +3376,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
33763376
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
33773377
if (hasAllRegions || !seenRegions.insert(region).second)
33783378
return emitError(loc, "region '" + name + "' is already bound");
3379-
} else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
3380-
return emitError(loc, "region '" + name +
3381-
"' must be bound before it is referenced");
3379+
} else if (ctx == RefDirectiveContext) {
3380+
if (!seenRegions.count(region))
3381+
return emitError(loc, "region '" + name +
3382+
"' must be bound before it is referenced");
33823383
} else {
3383-
return emitError(loc, "regions can only be used at the top level");
3384+
return emitError(loc, "regions can only be used at the top level "
3385+
"or in a ref directive");
33843386
}
33853387
return create<RegionVariable>(region);
33863388
}
@@ -3396,11 +3398,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
33963398
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
33973399
if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
33983400
return emitError(loc, "successor '" + name + "' is already bound");
3399-
} else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
3400-
return emitError(loc, "successor '" + name +
3401-
"' must be bound before it is referenced");
3401+
} else if (ctx == RefDirectiveContext) {
3402+
if (!seenSuccessors.count(successor))
3403+
return emitError(loc, "successor '" + name +
3404+
"' must be bound before it is referenced");
34023405
} else {
3403-
return emitError(loc, "successors can only be used at the top level");
3406+
return emitError(loc, "successors can only be used at the top level "
3407+
"or in a ref directive");
34043408
}
34053409

34063410
return create<SuccessorVariable>(successor);

0 commit comments

Comments
 (0)