diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir index 0b959915d6bbb..e2088817c5204 100644 --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -106,3 +106,10 @@ func.func @named_region_has_wrong_number_of_blocks() { test.single_no_terminator_custom_asm_op { "important_dont_drop"() : () -> () } + +// ----- + +// CHECK: test.dummy_op_with_region_ref +test.dummy_op_with_region_ref { + ^bb0: +} diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp index 9ed1b3a47be36..70bab21b83256 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp @@ -381,3 +381,26 @@ void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type, Attribute attr) { printer.printAttributeWithoutType(attr); } + +//===----------------------------------------------------------------------===// +// CustomDirectiveDummyRegionRef +//===----------------------------------------------------------------------===// + +ParseResult test::parseDummyRegionRef(OpAsmParser &parser, Region ®ion) { + return success(); +} + +void test::printDummyRegionRef(OpAsmPrinter &printer, Operation *op, + Region ®ion) { /* do nothing */ } + +//===----------------------------------------------------------------------===// +// CustomDirectiveDummySuccessorRef +//===----------------------------------------------------------------------===// + +ParseResult test::parseDummySuccessorRef(OpAsmParser &parser, + Block *successor) { + return success(); +} + +void test::printDummySuccessorRef(OpAsmPrinter &printer, Operation *op, + Block *successor) { /* do nothing */ } diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h index 6d4df7d82ffa5..e914f9a27b79b 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.h +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h @@ -207,6 +207,24 @@ mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser, void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op, mlir::TypeAttr type, mlir::Attribute attr); +//===----------------------------------------------------------------------===// +// CustomDirectiveDummyRegionRef +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseDummyRegionRef(mlir::OpAsmParser &parser, + mlir::Region ®ion); +void printDummyRegionRef(mlir::OpAsmPrinter &printer, mlir::Operation *op, + mlir::Region ®ion); + +//===----------------------------------------------------------------------===// +// CustomDirectiveDummySuccessorRef +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseDummySuccessorRef(mlir::OpAsmParser &parser, + mlir::Block *successor); +void printDummySuccessorRef(mlir::OpAsmPrinter &printer, mlir::Operation *op, + mlir::Block *successor); + } // end namespace test #endif // MLIR_TESTFORMATUTILS_H diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 1c961d272f192..0ad5bfa9a58ab 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3665,4 +3665,22 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { ); } +//===----------------------------------------------------------------------===// +// Test assembly format references +//===----------------------------------------------------------------------===// + +def TestOpWithRegionRef : TEST_Op<"dummy_op_with_region_ref", [NoTerminator]> { + let regions = (region AnyRegion:$body); + let assemblyFormat = [{ + $body attr-dict custom(ref($body)) + }]; +} + +def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> { + let successors = (successor AnySuccessor:$successor); + let assemblyFormat = [{ + $successor attr-dict custom(ref($successor)) + }]; +} + #endif // TEST_OPS diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td index 02bf65609b21a..03b63f42c7767 100644 --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -49,6 +49,19 @@ def DirectiveCustomValidD : TestFormat_Op<[{ def DirectiveCustomValidE : TestFormat_Op<[{ custom(prop-dict) attr-dict }]>, Arguments<(ins UnitAttr:$flag)>; +def DirectiveCustomValidF : TestFormat_Op<[{ + $operand custom(ref($operand)) attr-dict +}]>, Arguments<(ins Optional:$operand)>; +def DirectiveCustomValidG : TestFormat_Op<[{ + $body custom(ref($body)) attr-dict +}]> { + let regions = (region AnyRegion:$body); +} +def DirectiveCustomValidH : TestFormat_Op<[{ + $successor custom(ref($successor)) attr-dict +}]> { + let successors = (successor AnySuccessor:$successor); +} //===----------------------------------------------------------------------===// // functional-type diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td index 09e068b91a40b..1790737a3a349 100644 --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -109,3 +109,23 @@ def OptionalGroupC : TestFormat_Op<[{ def OptionalGroupD : TestFormat_Op<[{ (custom($a, $b)^)? attr-dict }], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$a, Optional:$b)>; + +// CHECK-LABEL: RegionRef::parse +// CHECK: auto odsResult = parseCustom(parser, *bodyRegion); +// CHECK-LABEL: RegionRef::print +// CHECK: printCustom(_odsPrinter, *this, getBody()); +def RegionRef : TestFormat_Op<[{ + $body custom(ref($body)) attr-dict +}]> { + let regions = (region AnyRegion:$body); +} + +// CHECK-LABEL: SuccessorRef::parse +// CHECK: auto odsResult = parseCustom(parser, successorSuccessor); +// CHECK-LABEL: SuccessorRef::print +// CHECK: printCustom(_odsPrinter, *this, getSuccessor()); +def SuccessorRef : TestFormat_Op<[{ + $successor custom(ref($successor)) attr-dict +}]> { + let successors = (successor AnySuccessor:$successor); +} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index d27814bc4541e..14af7787a833e 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -3376,11 +3376,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { if (hasAllRegions || !seenRegions.insert(region).second) return emitError(loc, "region '" + name + "' is already bound"); - } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) { - return emitError(loc, "region '" + name + - "' must be bound before it is referenced"); + } else if (ctx == RefDirectiveContext) { + if (!seenRegions.count(region)) + return emitError(loc, "region '" + name + + "' must be bound before it is referenced"); } else { - return emitError(loc, "regions can only be used at the top level"); + return emitError(loc, "regions can only be used at the top level " + "or in a ref directive"); } return create(region); } @@ -3396,11 +3398,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { if (hasAllSuccessors || !seenSuccessors.insert(successor).second) return emitError(loc, "successor '" + name + "' is already bound"); - } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) { - return emitError(loc, "successor '" + name + - "' must be bound before it is referenced"); + } else if (ctx == RefDirectiveContext) { + if (!seenSuccessors.count(successor)) + return emitError(loc, "successor '" + name + + "' must be bound before it is referenced"); } else { - return emitError(loc, "successors can only be used at the top level"); + return emitError(loc, "successors can only be used at the top level " + "or in a ref directive"); } return create(successor);