Skip to content

[mlir][tblgen] Fix region and successor references in custom directives #146242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/test/IR/region.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,10 @@ func.func @named_region_has_wrong_number_of_blocks() {
test.single_no_terminator_custom_asm_op {
"important_dont_drop"() : () -> ()
}

// -----

// CHECK: dummy_op_with_region_ref
test.dummy_op_with_region_ref() ({
^bb0:
}) : () -> ()
23 changes: 23 additions & 0 deletions mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,26 @@ void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
Attribute attr) {
printer.printAttributeWithoutType(attr);
}

//===----------------------------------------------------------------------===//
// CustomDirectiveDummyRegionRef
//===----------------------------------------------------------------------===//

ParseResult test::parseDummyRegionRef(OpAsmParser &parser, Region &region) {
return success();
}

void test::printDummyRegionRef(OpAsmPrinter &printer, Operation *op,
Region &region) { /* do nothing */ }

//===----------------------------------------------------------------------===//
// CustomDirectiveDummySuccessorRef
//===----------------------------------------------------------------------===//

ParseResult test::parseDummySuccessorRef(OpAsmParser &parser,
Block *successor) {
return success();
}

void test::printDummySuccessorRef(OpAsmPrinter &printer, Operation *op,
Block *successor) { /* do nothing */ }
18 changes: 18 additions & 0 deletions mlir/test/lib/Dialect/Test/TestFormatUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
mlir::TypeAttr type, mlir::Attribute attr);

//===----------------------------------------------------------------------===//
// CustomDirectiveDummyRegionRef
//===----------------------------------------------------------------------===//

mlir::ParseResult parseDummyRegionRef(mlir::OpAsmParser &parser,
mlir::Region &region);
void printDummyRegionRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
mlir::Region &region);

//===----------------------------------------------------------------------===//
// CustomDirectiveDummySuccessorRef
//===----------------------------------------------------------------------===//

mlir::ParseResult parseDummySuccessorRef(mlir::OpAsmParser &parser,
mlir::Block *successor);
void printDummySuccessorRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
mlir::Block *successor);

} // end namespace test

#endif // MLIR_TESTFORMATUTILS_H
18 changes: 18 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3665,4 +3665,22 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
);
}

//===----------------------------------------------------------------------===//
// Test assembly format references
//===----------------------------------------------------------------------===//

def TestOpWithRegionRef : TEST_Op<"dummy_op_with_region_ref"> {
let regions = (region AnyRegion:$body);
let assemblyFormat = [{
$body attr-dict custom<DummyRegionRef>(ref($body))
}];
}

def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
let successors = (successor AnySuccessor:$successor);
let assemblyFormat = [{
$successor attr-dict custom<DummySuccessorRef>(ref($successor))
}];
}

#endif // TEST_OPS
13 changes: 13 additions & 0 deletions mlir/test/mlir-tblgen/op-format-spec.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ def DirectiveCustomValidD : TestFormat_Op<[{
def DirectiveCustomValidE : TestFormat_Op<[{
custom<MyDirective>(prop-dict) attr-dict
}]>, Arguments<(ins UnitAttr:$flag)>;
def DirectiveCustomValidF : TestFormat_Op<[{
$operand custom<MyDirective>(ref($operand)) attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
def DirectiveCustomValidG : TestFormat_Op<[{
$body custom<MyDirective>(ref($body)) attr-dict
}]> {
let regions = (region AnyRegion:$body);
}
def DirectiveCustomValidH : TestFormat_Op<[{
$successor custom<MyDirective>(ref($successor)) attr-dict
}]> {
let successors = (successor AnySuccessor:$successor);
}

//===----------------------------------------------------------------------===//
// functional-type
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/mlir-tblgen/op-format.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,23 @@ def OptionalGroupC : TestFormat_Op<[{
def OptionalGroupD : TestFormat_Op<[{
(custom<Custom>($a, $b)^)? attr-dict
}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<I64>:$a, Optional<I64>:$b)>;

// CHECK-LABEL: RegionRef::parse
// CHECK: auto odsResult = parseCustom(parser, *bodyRegion);
// CHECK-LABEL: RegionRef::print
// CHECK: printCustom(_odsPrinter, *this, getBody());
def RegionRef : TestFormat_Op<[{
$body custom<Custom>(ref($body)) attr-dict
}]> {
let regions = (region AnyRegion:$body);
}

// CHECK-LABEL: SuccessorRef::parse
// CHECK: auto odsResult = parseCustom(parser, successorSuccessor);
// CHECK-LABEL: SuccessorRef::print
// CHECK: printCustom(_odsPrinter, *this, getSuccessor());
def SuccessorRef : TestFormat_Op<[{
$successor custom<Custom>(ref($successor)) attr-dict
}]> {
let successors = (successor AnySuccessor:$successor);
}
20 changes: 12 additions & 8 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3376,11 +3376,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
if (hasAllRegions || !seenRegions.insert(region).second)
return emitError(loc, "region '" + name + "' is already bound");
} else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
return emitError(loc, "region '" + name +
"' must be bound before it is referenced");
} else if (ctx == RefDirectiveContext) {
if (!seenRegions.count(region))
return emitError(loc, "region '" + name +
"' must be bound before it is referenced");
} else {
return emitError(loc, "regions can only be used at the top level");
return emitError(loc, "regions can only be used at the top level "
"or in a ref directive");
}
return create<RegionVariable>(region);
}
Expand All @@ -3396,11 +3398,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
return emitError(loc, "successor '" + name + "' is already bound");
} else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
return emitError(loc, "successor '" + name +
"' must be bound before it is referenced");
} else if (ctx == RefDirectiveContext) {
if (!seenSuccessors.count(successor))
return emitError(loc, "successor '" + name +
"' must be bound before it is referenced");
} else {
return emitError(loc, "successors can only be used at the top level");
return emitError(loc, "successors can only be used at the top level "
"or in a ref directive");
}

return create<SuccessorVariable>(successor);
Expand Down
Loading