Skip to content

frontend: support conditionals as scf.if statements #1571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,6 @@ repos:
rev: 24.10.1
hooks:
- id: pyink
language_version: python3.11
language_version: python3.12

exclude: patches/.*\.patch$
25 changes: 25 additions & 0 deletions frontend/conditional_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from heir import compile
from heir.mlir import I1, I64, Secret

from absl.testing import absltest # fmt: skip


class EndToEndTest(absltest.TestCase):

def test_cond(self):

@compile(debug=True)
def cond(a: Secret[I64], b: Secret[I1]):
result = 0
if b:
result = a
else:
result = 0
result = result + 1
return result

self.assertEqual(2, cond(2))


if __name__ == "__main__":
absltest.main()
158 changes: 132 additions & 26 deletions frontend/heir/mlir_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,52 @@ class Loop:
inits: list[ir.Var]


def get_ambient_vars(block):
# Get vars defined in the ambient scope.
vars = set()
for instr in block.body:
match type(instr):
case ir.Assign:
if type(instr.value) == ir.Global:
continue
# Defined before the loop block
if not instr.target.is_temp:
if instr.target.loc.line < block.loc.line:
vars.add(instr.target)
return vars


def get_vars_used_after(block, post_dominators, ssa_ir):
# Find any vars assigned in the block scopes used in the post dominators.
block_vars = set()
for instr in block.body:
match type(instr):
case ir.Assign:
if type(instr.value) == ir.Global:
continue
if not instr.target.is_temp:
block_vars.add(instr.target)

vars = set()
for id in post_dominators:
pblock = ssa_ir.blocks[id]
for instr in pblock.body:
match type(instr):
case ir.Assign:
if type(instr.value) == ir.Global:
continue
# Defined before the loop block
for var in instr.list_vars():
if var.is_temp:
continue
if var == instr.target:
continue
# var must be defined in the block
if var in block_vars:
vars.add(var)
return list(vars)


def build_loop_from_call(index, block_id, blocks, cfa):
body = blocks[block_id].body
# Build a loop from a range call starting at index
Expand All @@ -132,17 +178,8 @@ def build_loop_from_call(index, block_id, blocks, cfa):

for id in loop_body_blocks:
block = blocks[id]
for instr in blocks[id].body:
match type(instr):
case ir.Assign:
if type(instr.value) == ir.Global:
continue
if instr.target == iter_var:
continue
# not temp and defined outside before the loop block
if not instr.target.is_temp:
if instr.target.loc.line < block.loc.line:
inits.add(instr.target)
ambient_vars = get_ambient_vars(block)
[inits.add(var) for var in ambient_vars if var != iter_var]

return Loop(
header_id,
Expand Down Expand Up @@ -196,6 +233,7 @@ def __init__(self, ssa_ir, secret_args: list[int], typemap, return_types):
self.numba_names_to_ssa_var_names = {}
self.globals_map = {}
self.loops = {}
self.cfa = self.get_control_flow()

def get_control_flow(self):
bc = bytecode.ByteCode(self.ssa_ir.func_id)
Expand Down Expand Up @@ -229,22 +267,26 @@ def emit(self):

def emit_blocks(self):
blocks = self.ssa_ir.blocks
cfa = self.get_control_flow()
for block_id, block in blocks.items():
print(block_id)
for instr in block.body:
print("\t" + str(instr))

# collect loops and block header needs
block_ids_to_omit_header = set()
loop_entries = [list(l.entries)[0] for l in cfa.graph._loops.values()]
block_ids_to_omit_header.update(self.cfa.graph.backbone())
loop_entries = [list(l.entries)[0] for l in self.cfa.graph._loops.values()]
for entry_id in loop_entries:
block = blocks[entry_id]
for i in range(len(block.body)):
# Detect a range call
instr = block.body[i]
if is_start_of_loop(i, block.body, self.ssa_ir):
loop = build_loop_from_call(i, entry_id, blocks, cfa)
loop = build_loop_from_call(i, entry_id, blocks, self.cfa)
self.loops[instr.target] = loop
block_ids_to_omit_header.add(loop.header.next_id)

sorted_blocks = list(cfa.iterblocks())
sorted_blocks = list(self.cfa.iterblocks())
# first block doesn't require a block header
block_ids_to_omit_header.add(sorted_blocks[0].offset)
blocks_to_print = deque(
Expand Down Expand Up @@ -323,6 +365,9 @@ def get_name(self, var):
assert var.name in self.numba_names_to_ssa_var_names
return self.get_or_create_name(var)

def has_name(self, var):
return var.name in self.numba_names_to_ssa_var_names

def forward_name(self, from_var, to_var):
to_name = self.numba_names_to_ssa_var_names[to_var.name]
self.numba_names_to_ssa_var_names[from_var.name] = to_name
Expand Down Expand Up @@ -366,12 +411,20 @@ def emit_assign(self, assign):
self.forward_name(from_var=assign.target, to_var=assign.value.value)
return ""
case ir.Const():
name = self.get_or_create_name(assign.target)
return (
# if we reassign const, then forward the name
reassign = self.has_name(assign.target)
if reassign:
name = self.get_next_name()
else:
name = self.get_or_create_name(assign.target)
const_str = (
f"{name} = arith.constant {assign.value.value} :"
f" {mlirType(self.typemap.get(assign.target.name))}"
f" {mlirLoc(assign.loc)}"
)
if reassign:
self.forward_name_to_id(assign.target, name.strip("%"))
return const_str
case ir.Global():
self.globals_map[assign.target.name] = assign.value.name
return ""
Expand Down Expand Up @@ -410,6 +463,9 @@ def emit_ext_if_needed(self, lhs, rhs):
short, long = rhs, lhs

tmp = self.get_next_name()
import ipdb

ipdb.set_trace()
Comment on lines +466 to +468
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import ipdb
ipdb.set_trace()

ext = (
f"{tmp} = arith.extui {self.get_name(short)} : "
f"{mlirType(self.typemap.get(str(short)))} "
Expand Down Expand Up @@ -456,17 +512,58 @@ def emit_binop(self, binop):
raise NotImplementedError("Unsupported binop: " + binop.fn.__name__)

def emit_branch(self, branch, blocks_to_print):
# We must ensure that both branches jump to the next execution. Do not
# support early exits.
condvar = self.get_name(branch.cond)
branches = [branch.truebr, branch.falsebr]
branch_strs = [
f"cf.cond_br {condvar}, ^bb{branch.truebr}, ^bb{branch.falsebr}"
]

branch_blocks = {}
for _ in range(2):
block_id, branch_block = blocks_to_print.popleft()
assert block_id in branches
body_str = self.emit_block(branch_block, blocks_to_print)
block_header = f"^bb{block_id}:\n"
branch_strs.append(block_header + textwrap.indent(body_str, " "))
branch_blocks[block_id] = branch_block

postdoms = self.cfa.graph.post_dominators()[branch.truebr]
postdoms.remove(branch.truebr)
# Variables to return from scf.if statement
retvars = get_vars_used_after(
branch_blocks[branch.truebr], postdoms, self.ssa_ir
)
falseretvars = get_vars_used_after(
branch_blocks[branch.falsebr], postdoms, self.ssa_ir
)
if retvars != retvars:
# We should be able to handle this by taking the union of each.
raise ValueError(
"Currently only supports branches that modify the same vars"
)

assert branch.truebr in branch_blocks
previous_state = self.numba_names_to_ssa_var_names.copy()
true_str = self.emit_block(branch_blocks[branch.truebr], blocks_to_print)
true_vars = [self.get_name(i) for i in retvars]
self.numba_names_to_ssa_var_names.clear()
self.numba_names_to_ssa_var_names = previous_state # rewind state
false_str = self.emit_block(branch_blocks[branch.falsebr], blocks_to_print)
false_vars = [self.get_name(i) for i in retvars]

resulttypes = [mlirType(self.typemap.get(str(i))) for i in retvars]
true_str += f"\nscf.yield {", ".join(true_vars)} : {", ".join(resulttypes)}"
false_str += (
f"\nscf.yield {", ".join(false_vars)} : {", ".join(resulttypes)}"
)

branch_stmt = f"scf.if {condvar}"
resultvars = [self.forward_to_new_id(i) for i in retvars]
if retvars:
results = ", ".join(resultvars)
typestr = ", ".join([f"{ty}" for ty in resulttypes])
branch_stmt = f"{results} = {branch_stmt} -> ({typestr})"
branch_stmt += " {"
branch_strs = [branch_stmt]
branch_strs.append(textwrap.indent(true_str, " "))
branch_strs.append("} else {")
branch_strs.append(textwrap.indent(false_str, " "))
branch_strs.append("}")

return "\n".join(branch_strs)

def emit_var_or_int(self, var_or_int):
Expand Down Expand Up @@ -532,7 +629,16 @@ def emit_loop(self, target, blocks_to_print):
if type(instr) == ir.Assign and instr.target in self.loops:
raise NotImplementedError("Nested loops are not supported")

body_str = self.emit_block(loop_block, blocks_to_print)
body_str = ""
itvar = self.get_name(loop.header.phi_var)
it = self.ssa_ir.get_assignee(loop.header.phi_var)
var_name = self.get_or_create_name(it)
body_str += (
f"{var_name} = arith.index_cast {itvar} : index to"
f" {mlirType(self.typemap.get(it.name))}\n"
)
self.forward_name(loop.header.phi_var, it)
body_str += self.emit_block(loop_block, blocks_to_print)
if len(loop.inits) > 1:
# Yield the iter args
yield_vars = ", ".join([self.get_name(init) for init in loop.inits])
Expand Down
Loading