Skip to content

[mlir][spirv] Add support for spirv.mlir.break #138688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions mlir/docs/Dialects/SPIR-V.md
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,61 @@ For example
}
```

#### Early Exit

In the current form loops do support an early exit as any block can branch to
the merge block of the loop. However, the problem arises when such early exit
is conditional and the branch is sunk into a `spirv.mlir.selection` region.
In such structure the branch inside the selection region cannot reference block
of the loop enclosing the selection. At the same time such pattern is not unusual.
To support early loop exit within nested structured control flow, SPIR-V dialect
introduces `spirv.mlir.break` operation. The semantic of this operation is to branch
to the merge block of the first enclosing loop.

For example

```mlir
spirv.mlir.loop {
spirv.Branch ^header(%zero: i32)

^header(%i : i32):
%cmp = spirv.SLessThan %i, %count : i32
spirv.BranchConditional %cmp, ^body, ^merge_loop

^body:
%cond = spirv.SGreaterThan %i, %five : i32
spirv.Branch ^selection

^selection:
spirv.mlir.selection {
spirv.BranchConditional %cond, ^true, ^merge_sel
^true:
spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region.
^merge_sel:
spirv.mlir.merge
}

spirv.Branch ^continue

^continue:
%new_i = spirv.IAdd %i, %one : i32
spirv.Branch ^header(%new_i: i32)

^merge_loop:
spirv.mlir.merge
}
```

The equivalent GLSL or C code would be

```c
for (int i = 0; i < 10; ++i) {
x += 1;
if(x > 5)
break;
}
```

### Block argument for Phi

There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
Expand Down
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"

// -----

// TODO: This is not only specific to control flow ops, so it could be moved
// somewhere else.
class SPIRV_HasParentOfType<string op> : PredOpTrait<
"op expects to be nested in " # op,
CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr">
>;

// -----

def SPIRV_BranchOp : SPIRV_Op<"Branch", [
DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope, Pure,
Terminator]> {
Expand Down Expand Up @@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
let hasRegionVerifier = 1;
}

// -----

def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [
Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> {
let summary = "Early exit from a structured loop.";

let description = [{
Since the SPIR-V dialect relies on structured control flow, early exit using
branches is not possible. Since branch cannot reference blocks outside a region
a `spirv.mlir.selection` cannot arbitrarily branch to the merge block of the
enclosing loop.

To provide support for early exits dialect implements a `spirv.mlir.break`
operation. The semantic of the operation is like that in GLSL / C / C++.
The break operation should be treated as a branch to the merge block of the
enclosing loop.
}];

let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
let hasOpcode = 0;
let autogenSerialization = 0;
let hasVerifier = 0;
}

#endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
44 changes: 44 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
return success();
}

LogicalResult spirv::Deserializer::handleEarlyExits() {
SmallVector<Block *> loopMergeBlocks;

// Find all blocks that are loops' merge blocks.
for (auto &[_, mergeInfo] : blockMergeInfo)
if (mergeInfo.continueBlock)
loopMergeBlocks.push_back(mergeInfo.mergeBlock);

for (auto &[header, mergeInfo] : blockMergeInfo) {
// We look for something like `if(x) break; ...` so we only process
// selection for now.
if (!mergeInfo.continueBlock) {
SetVector<Block *> constructBlocks;
constructBlocks.insert(header);

// Iterate over all blocks in the selection. This is similar to
// `collectBlocksInConstruct()` but with extra logic inserting
// `spirv.mlir.break`. We look for any block inside the selection region
// that jumps directly to the loop merge and does not go through the merge
// block of the selection. This indicates the unstructured jump so the
// branch is replaced with break.
for (unsigned i = 0; i < constructBlocks.size(); ++i) {
for (Block *successor : constructBlocks[i]->getSuccessors()) {
Block *block = constructBlocks[i];
if (llvm::is_contained(loopMergeBlocks, successor)) {
assert(!block->empty() && block->getNumSuccessors() == 1);
block->back().erase();
OpBuilder builder(block, block->end());
builder.create<spirv::BreakOp>(mergeInfo.loc);
}
if (successor != mergeInfo.mergeBlock)
constructBlocks.insert(successor);
}
}
}
}

return success();
}

LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
Expand All @@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
return failure();
}

if (failed(handleEarlyExits())) {
return failure();
}

// TODO: This loop is non-deterministic. Iteration order may vary between runs
// for the same shader as the key to the map is a pointer. See:
// https://github.com/llvm/llvm-project/issues/128547
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ class Deserializer {
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();

/// Detect unstructured early exits from loops and replaces those arbitrary
/// branches with `spirv.mlir.break` statements.
LogicalResult handleEarlyExits();

//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
return success();
}

LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) {
auto parentLoopOp = breakOp.getOperation()->getParentOfType<spirv::LoopOp>();

if (!parentLoopOp)
return failure();

auto *mergeBlock = parentLoopOp.getMergeBlock();
auto mergeID = getBlockID(mergeBlock);

encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID});

return success();
}

LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.getCondition());
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::BreakOp op) { return processBreakOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ class Serializer {

LogicalResult processLoopOp(spirv::LoopOp loopOp);

LogicalResult processBreakOp(spirv::BreakOp breakOp);

LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);

LogicalResult processBranchOp(spirv::BranchOp branchOp);
Expand Down
70 changes: 70 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () {

// -----

func.func @loop_break(%count : i32) -> () {
%zero = spirv.Constant 0: i32
%one = spirv.Constant 1: i32
%five = spirv.Constant 5: i32

// CHECK: spirv.mlir.loop {
spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
spirv.Branch ^header(%zero: i32)

// CHECK-NEXT: ^bb1({{%.*}}: i32):
^header(%i : i32):
%cmp = spirv.SLessThan %i, %count : i32
// CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5
spirv.BranchConditional %cmp, ^body, ^merge

// CHECK-NEXT: ^bb2:
^body:
%cond = spirv.SGreaterThan %i, %five : i32

// CHECK: spirv.Branch ^bb3
spirv.Branch ^selection

// CHECK-NEXT: ^bb3:
^selection:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
// CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
spirv.BranchConditional %cond, ^true, ^merge
// CHECK-NEXT: ^bb1:
^true:
// CHECK-NEXT: spirv.mlir.break
spirv.mlir.break
// CHECK-NEXT: ^bb2:
^merge:
// CHECK-NEXT: spirv.mlir.merge
spirv.mlir.merge
}

// CHECK: spirv.Branch ^bb4
spirv.Branch ^continue

// CHECK-NEXT: ^bb4:
^continue:
%new_i = spirv.IAdd %i, %one : i32
// CHECK: spirv.Branch ^bb1({{%.*}}: i32)
spirv.Branch ^header(%new_i: i32)

// CHECK-NEXT: ^bb5:
^merge:
// CHECK-NEXT: spirv.mlir.merge
spirv.mlir.merge
}

return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.mlir.break
//===----------------------------------------------------------------------===//

func.func @break() -> () {
// expected-error @+1 {{op expects to be nested in LoopOp}}
spirv.mlir.break
}

// -----

//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Target/SPIRV/loop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.Return
}
}

// -----

// Loop with break statement

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.func @loop_break(%count : i32) -> () "None" {
%zero = spirv.Constant 0: i32
%one = spirv.Constant 1: i32
%five = spirv.Constant 5: i32

// CHECK: spirv.mlir.loop {
spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
spirv.Branch ^header(%zero: i32)

// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
^header(%i : i32):
%cmp = spirv.SLessThan %i, %count : i32
// CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
spirv.BranchConditional %cmp, ^body, ^merge

// CHECK-NEXT: ^[[BODY:.+]]:
^body:
%cond = spirv.SGreaterThan %i, %five : i32

// CHECK: spirv.Branch ^[[LINK:.+]]
spirv.Branch ^selection

// COM: Artificial block introduced by block splitting in the deserializer.
// CHECK-NEXT: ^[[LINK:.+]]:
// CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]]

// CHECK-NEXT: ^[[SELECTION:.+]]:
^selection:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
// CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]]
spirv.BranchConditional %cond, ^true, ^merge
// CHECK-NEXT: ^[[TRUE:.+]]:
^true:
// CHECK-NEXT: spirv.mlir.break
spirv.mlir.break
// CHECK-NEXT: ^[[MERGE:.+]]:
^merge:
// CHECK-NEXT: spirv.mlir.merge
spirv.mlir.merge
}

// CHECK: spirv.Branch ^[[CONTINUE:.+]]
spirv.Branch ^continue

// CHECK-NEXT: ^[[CONTINUE:.+]]:
^continue:
%new_i = spirv.IAdd %i, %one : i32
// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
spirv.Branch ^header(%new_i: i32)

// CHECK-NEXT: ^[[MERGE:.+]]:
^merge:
// CHECK-NEXT: spirv.mlir.merge
spirv.mlir.merge
}

// CHECK: spirv.Return
spirv.Return
}
}