Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR for MoveLinearPastEltwiseMul transformation #1275

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/finn/transformation/streamline/reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,75 @@ def apply(self, model):
return (model, graph_modified)


class MoveLinearPastEltwiseMul(Transformation):
"""Move linear operations (mul) past elementwise mul operations where possible.
Specifically,matches and transforms the following patterns:
(x*A) * (y*B) -> (xy)*(A*B)
where x and y are dynamic inputs, A, B are constant tensors (in general).
"""

def move_node(self, graph, n, prod0, prod1, node_ind):
# found! move one of the muls to output, remove the other one
lin0_in0 = prod0.input[0]
lin1_in0 = prod1.input[0]
in0 = n.input[0]
out = n.output[0]
# connect the eltwise mul inputs to mul inputs
n.input[0] = lin0_in0
n.input[1] = lin1_in0
# connect mul0 output to eltwise mul output
prod0.output[0] = out
# connect the input of mul0 and output of eltwise mul together
n.output[0] = in0
prod0.input[0] = in0
# move prod0 node past eltwise mul node, and remove prod1
graph.node.remove(prod1)
graph.node.remove(prod0)
graph.node.insert(node_ind - 2, prod0)

def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
nodes = [n for n in graph.node]
for n in nodes:
node_ind += 1
# checking if the operation is eltwisemul
if n.op_type == "Mul":
in0 = n.input[0]
in1 = n.input[1]
if in0 is None or in1 is None:
continue
A = model.get_initializer(in0)
B = model.get_initializer(in1)
if A is not None or B is not None:
continue
# check for mul with same initializer on both inputs
prod0 = model.find_producer(in0)
prod1 = model.find_producer(in1)
if prod0 is None or prod1 is None or (prod0 == prod1):
continue
if len(prod0.input) < 2 or len(prod1.input) < 2:
continue
init0 = model.get_initializer(prod0.input[1])
init1 = model.get_initializer(prod1.input[1])
# if either initializer is None, skip
if init0 is None or init1 is None:
continue
if prod0.op_type == "Mul" and prod1.op_type == "Mul":
# Adding the update intializer condition
init = init0 * init1
# update initializer of prod0, the node which will move
model.set_initializer(prod0.input[1], init)
self.move_node(graph, n, prod0, prod1, node_ind)
node_ind -= 1
graph_modified = True
else:
continue
model = model.transform(InferShapes())
return (model, graph_modified)


class MoveScalarLinearPastInvariants(Transformation):
"""Move scalar linear operations (mul, add) past functions which are invariant
to them. Specifically, matches and transforms the following patterns:
Expand Down