-
Notifications
You must be signed in to change notification settings - Fork 9
Introduce insert_nodes convenience function #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
"replace_all_uses_with", | ||
"create_value_mapping", | ||
"replace_nodes_and_values", | ||
"insert_nodes_in_value", | ||
] | ||
|
||
from collections.abc import Mapping, Sequence | ||
|
@@ -336,6 +337,18 @@ | |
return values | ||
|
||
|
||
def _update_graph_or_function_outputs( | ||
graph_or_function: _core.Graph | _core.Function, | ||
old_values: Sequence[_core.Value], | ||
new_values: Sequence[_core.Value], | ||
): | ||
"""Update graph/function outputs.""" | ||
replacement_mapping = dict(zip(old_values, new_values)) | ||
for idx, graph_or_function_output in enumerate(graph_or_function.outputs): | ||
if graph_or_function_output in replacement_mapping: | ||
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] | ||
|
||
|
||
def replace_nodes_and_values( | ||
graph_or_function: _core.Graph | _core.Function, | ||
/, | ||
|
@@ -367,11 +380,113 @@ | |
# Reconnect the users of the deleted values to use the new values | ||
replace_all_uses_with(old_values, new_values) | ||
# Update graph/function outputs if the node generates output | ||
replacement_mapping = dict(zip(old_values, new_values)) | ||
for idx, graph_or_function_output in enumerate(graph_or_function.outputs): | ||
if graph_or_function_output in replacement_mapping: | ||
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] | ||
_update_graph_or_function_outputs(graph_or_function, old_values, new_values) | ||
|
||
# insert new nodes after the index node | ||
graph_or_function.insert_after(insertion_point, new_nodes) | ||
graph_or_function.remove(old_nodes, safe=True) | ||
|
||
|
||
def _find_inputs_outputs( | ||
nodes: Sequence[_core.Node], | ||
) -> tuple[tuple[_core.Value | None, ...], tuple[_core.Value, ...]]: | ||
"""Find the values that are considered as inputs and outputs in a sequence of nodes.""" | ||
# Search the unique inputs/outputs in new_nodes, keeping the order. | ||
all_inputs = dict.fromkeys(sum((node.inputs for node in nodes), ())) # type: ignore[type-var] | ||
all_outputs = dict.fromkeys(sum((node.outputs for node in nodes), ())) # type: ignore[type-var] | ||
# A value is considered as input if it is not any output. | ||
inputs = tuple(val for val in all_inputs if val not in all_outputs) | ||
# A value is considered as output if it is not any input. | ||
outputs = tuple(val for val in all_outputs if val not in all_inputs) | ||
return inputs, outputs | ||
|
||
|
||
def insert_nodes_in_value( | ||
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you share a use case for this function? I am trying to think what the best api for it should be. Right now the order of appearance of the node inputs implicitly matches the values, which may not be ideal. Also the function name may be improved to be more accurate and succinct. If we generalize this method, what should an “insert_nodes” api do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am thinking in some case that is needed to insert new nodes without remove previous ones. Following: microsoft/onnxscript#2064 discussion, this is not possible to do this with onnxscript.rewriter, since rewrite_pattern have to redefine target_pattern. Moreover I believe with this helper it is easer to include new nodes with something like for node in ir_model.graph:
if node.op_type == 'expected_op_type':
insert_node(node.outputs[0], [new_node]) I found this ease to understand/use. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For your other comment (about order), do you have any suggestion ? When len(values) > 1, I though was a good idea to infer the inputs for the list of nodes keeping the order, being needed to change the position of nodes in the list if values need to be re-order. Maybe it would be enough if it is better specified in the use of the function, right? |
||
) -> None: | ||
"""Inserts a sequence of nodes into the provided value(s). | ||
|
||
This allows to insert a list of LINKED nodes (over the same context) at | ||
a specific point in the graph. | ||
|
||
For example, suppose we have the following graph:: | ||
|
||
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output | ||
|
||
We want to insert [node_M, node_N] at B value:: | ||
|
||
>>> import onnx_ir as ir | ||
>>> input = ir.Input("input") | ||
>>> node_A = ir.node("op_A", [input]) | ||
>>> B = ir.Value(name="B") | ||
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B]) | ||
>>> node_C = ir.node("op_C", node_B.outputs) | ||
>>> # Create a new sequence to insert | ||
>>> input_2 = ir.Input("input_2") | ||
>>> node_M = ir.node("op_M", [input_2]) | ||
>>> node_N = ir.node("op_N", node_M.outputs) | ||
>>> # Insert nodes in B | ||
>>> insert_nodes_in_value(node_B.outputs, [node_M, node_N]) | ||
>>> len(node_B.outputs) | ||
1 | ||
>>> node_B.outputs[0].consumers()[0].op_type | ||
'op_M' | ||
>>> len(node_C.inputs) | ||
1 | ||
>>> node_C.inputs[0].producer().op_type | ||
'op_N' | ||
>>> node_C.inputs[0].name | ||
'B' | ||
|
||
When values is a sequence, the set of nodes must have the same number | ||
of inputs and outputs, then they are zipped into pairs: first value is | ||
replaced with the first input/output, and so on. | ||
|
||
Args: | ||
values: The value(s) where to insert the nodes. | ||
new_nodes: The nodes to insert in the graph. | ||
""" | ||
if not isinstance(values, Sequence): | ||
values = (values,) | ||
|
||
# Search the unique inputs/outputs in new_nodes, keeping the order. | ||
inputs, outputs = _find_inputs_outputs(new_nodes) | ||
|
||
# Sanity check. | ||
if len(values) != len(inputs): | ||
raise ValueError( | ||
f"The number of values and inputs ({inputs}) in new_nodes must match." | ||
) | ||
if len(values) != len(outputs): | ||
raise ValueError( | ||
f"The number of values and outputs ({outputs}) in new_nodes must match." | ||
) | ||
|
||
# Propagate relevant info. | ||
for val, in_val, out_val in zip(values, inputs, outputs): | ||
# Propagate relevant info from value to out_value. | ||
# TODO(Rama): Perhaps this should be a separate utility function. | ||
out_val.type = val.type | ||
out_val.shape = val.shape | ||
out_val.name = val.name | ||
# Propagate relevant info from value to in_value. | ||
# TODO(Rama): Perhaps this should be a separate utility function. | ||
in_val.type = val.type | ||
Check failureCode scanning / lintrunner MYPY/union-attr Error
Item "None" of "Value | None" has no attribute "type"
To disable, use # type: ignore[union-attr]
|
||
in_val.shape = val.shape | ||
Check failureCode scanning / lintrunner MYPY/union-attr Error
Item "None" of "Value | None" has no attribute "shape"
To disable, use # type: ignore[union-attr]
|
||
# Rename each value, following each input. | ||
val.name = in_val.name | ||
Check failureCode scanning / lintrunner MYPY/union-attr Error
Item "None" of "Value | None" has no attribute "name"
To disable, use # type: ignore[union-attr]
|
||
|
||
# Insert the new nodes in two steps: | ||
# 1. Reconnect the users of values to the outputs | ||
replace_all_uses_with(values, outputs) | ||
# 2. Reconnect the users of inputs to values | ||
replace_all_uses_with(inputs, values) | ||
Check failureCode scanning / lintrunner MYPY/arg-type Error
Argument 1 to "replace_all_uses_with" has incompatible type "tuple[Value | None, ...]"; expected "ValueProtocol | Sequence[ValueProtocol]"
To disable, use # type: ignore[arg-type]
|
||
|
||
# Update graph if there is one: | ||
if (graph := values[-1].graph) is not None: | ||
# Update graph/function outputs if the node generates output | ||
_update_graph_or_function_outputs(graph, values, outputs) | ||
|
||
# Insert new nodes if there is a graph | ||
graph.extend(new_nodes) | ||
graph.sort() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorting after insertion is not ideal. Is there a way to efficiently find the insertion point? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Unit tests for the _convenience module.""" | ||
|
||
import unittest | ||
|
||
import onnx | ||
|
||
import onnx_ir as ir | ||
from onnx_ir._convenience import insert_nodes_in_value | ||
|
||
|
||
def _create_model(model_text: str) -> ir.Model: | ||
model = onnx.parser.parse_model(model_text) | ||
return ir.serde.deserialize_model(model) | ||
|
||
|
||
class ConvenienceTest(unittest.TestCase): | ||
def test_insert_nodes_in_value(self): | ||
# Main graph | ||
input = ir.Input("input") | ||
node_A = ir.node("op_A", [input]) | ||
node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")]) | ||
node_C = ir.node("op_C", node_B.outputs) | ||
|
||
# New sequence to insert | ||
input_2 = ir.Input("input_2") | ||
node_M = ir.node("op_M", [input_2]) | ||
node_N = ir.node("op_N", node_M.outputs) | ||
|
||
# Insert nodes in B | ||
insert_nodes_in_value(node_B.outputs[0], [node_M, node_N]) | ||
self.assertEqual(len(node_B.outputs), 1) | ||
self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M") | ||
self.assertEqual(len(node_C.inputs), 1) | ||
self.assertEqual(node_C.inputs[0].producer().op_type, "op_N") | ||
self.assertEqual(node_C.inputs[0].name, "B") | ||
|
||
def test_insert_nodes_in_value_in_graph(self): | ||
ir_model = _create_model( | ||
""" | ||
<ir_version: 10, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) { | ||
two = Constant<value_float=2.0>() | ||
a, b = SplitNode(x) | ||
z = MergeNode(a, b, two) | ||
} | ||
""" | ||
) | ||
|
||
# Sequence to insert. | ||
# Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]]. | ||
i1, i2 = ir.Input("i1"), ir.Input("i2") | ||
a = ir.node("op_1", [i1, i2]) | ||
b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2) | ||
c = ir.node("op_3", [i2, b.outputs[0]]) | ||
|
||
# Insert nodes in SplitNode.outputs | ||
target_node = ir_model.graph[1] | ||
insert_nodes_in_value(target_node.outputs, [a, b, c]) | ||
|
||
# Check target_node outputs have been renamed | ||
new_i1, new_i2 = target_node.outputs | ||
self.assertEqual(new_i1.name, "i1") | ||
self.assertEqual(new_i2.name, "i2") | ||
|
||
# Check i1 and i2 have new users | ||
self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")) | ||
self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")) | ||
|
||
# Check outputs have been correctly renamed as previous values | ||
self.assertEqual(b.outputs[1].name, "a") | ||
self.assertEqual(c.outputs[0].name, "b") | ||
|
||
# Check nodes have been inserted in the graph | ||
self.assertEqual(len(ir_model.graph), 6) | ||
|
||
def test_insert_nodes_in_input(self): | ||
ir_model = _create_model( | ||
""" | ||
<ir_version: 10, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) { | ||
two = Constant<value_float=2.0>() | ||
z = Add(x, two) | ||
} | ||
""" | ||
) | ||
|
||
# Sequence to insert. | ||
x = ir.Input("new_x") | ||
node = ir.node("Mul", [x, x]) | ||
|
||
# Insert nodes in graph.inputs | ||
insert_nodes_in_value(ir_model.graph[1].inputs[0], [node]) | ||
self.assertEqual(node.outputs[0].name, "x") | ||
|
||
# Check input has been renamed | ||
self.assertEqual(ir_model.graph.inputs[0].name, "new_x") | ||
|
||
# Finally, check new graph is valid | ||
proto = ir.to_proto(ir_model) | ||
onnx.checker.check_model(proto, full_check=True) | ||
|
||
def test_insert_nodes_in_output(self): | ||
ir_model = _create_model( | ||
""" | ||
<ir_version: 10, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) { | ||
two = Constant<value_float=2.0>() | ||
z = Add(x, two) | ||
} | ||
""" | ||
) | ||
|
||
# Sequence to insert. | ||
x = ir.Input("new_z") | ||
node = ir.node("Mul", [x, x]) | ||
|
||
# Insert nodes in graph.inputs | ||
insert_nodes_in_value(ir_model.graph.outputs[0], [node]) | ||
self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z") | ||
|
||
# Check output name is preserved | ||
self.assertEqual(ir_model.graph.outputs[0].name, "z") | ||
|
||
def test_value_error_for_wrong_number_of_points(self): | ||
ir_model = _create_model( | ||
""" | ||
<ir_version: 10, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) { | ||
two = Constant<value_float=2.0>() | ||
a, b = SplitNode(x) | ||
z = MergeNode(a, b, two) | ||
} | ||
""" | ||
) | ||
node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")]) | ||
with self.assertRaisesRegex(ValueError, "The number of values and inputs"): | ||
insert_nodes_in_value(ir_model.graph[0].outputs, [node]) | ||
|
||
with self.assertRaisesRegex(ValueError, "The number of values and outputs"): | ||
insert_nodes_in_value(ir_model.graph[1].outputs, [node]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.