Skip to content

Introduce insert_nodes convenience function #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/ir_convenience.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
.. autofunction:: replace_all_uses_with
.. autofunction:: replace_nodes_and_values
.. autofunction:: create_value_mapping
.. autofunction:: insert_nodes_in_value
```
123 changes: 119 additions & 4 deletions src/onnx_ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
"insert_nodes_in_value",
]

from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -336,6 +337,18 @@
return values


def _update_graph_or_function_outputs(
    graph_or_function: _core.Graph | _core.Function,
    old_values: Sequence[_core.Value],
    new_values: Sequence[_core.Value],
):
    """Swap replaced values in the outputs of a graph or function, in place.

    ``old_values`` and ``new_values`` are paired positionally. Every entry of
    ``graph_or_function.outputs`` that is one of the old values is overwritten
    with its paired new value; all other outputs are left untouched.

    Args:
        graph_or_function: The graph or function whose outputs are updated.
        old_values: The values being replaced.
        new_values: The replacements, in the same order as ``old_values``.
    """
    substitutes = dict(zip(old_values, new_values))
    for position, current in enumerate(graph_or_function.outputs):
        if current not in substitutes:
            continue
        graph_or_function.outputs[position] = substitutes[current]


def replace_nodes_and_values(
graph_or_function: _core.Graph | _core.Function,
/,
Expand Down Expand Up @@ -367,11 +380,113 @@
# Reconnect the users of the deleted values to use the new values
replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)

# insert new nodes after the index node
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.remove(old_nodes, safe=True)


def _find_inputs_outputs(
    nodes: Sequence[_core.Node],
) -> tuple[tuple[_core.Value, ...], tuple[_core.Value, ...]]:
    """Find the boundary inputs and outputs of a sequence of nodes.

    A value is a boundary input when it is consumed by one of ``nodes`` but
    produced by none of them, and a boundary output when it is produced by one
    of ``nodes`` but consumed by none of them. ``None`` entries in
    ``node.inputs`` carry no value and are skipped, so callers only ever
    receive real values.

    Args:
        nodes: The nodes to inspect.

    Returns:
        A pair ``(inputs, outputs)`` of tuples, each holding unique values in
        order of first appearance.
    """
    # Dicts serve as insertion-ordered sets: unique entries, stable order.
    consumed = dict.fromkeys(
        value for node in nodes for value in node.inputs if value is not None
    )
    produced = dict.fromkeys(value for node in nodes for value in node.outputs)
    # A value is considered an input if no node in the sequence produces it.
    inputs = tuple(value for value in consumed if value not in produced)
    # A value is considered an output if no node in the sequence consumes it.
    outputs = tuple(value for value in produced if value not in consumed)
    return inputs, outputs


def insert_nodes_in_value(
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you share a use case for this function? I am trying to think what the best api for it should be. Right now the order of appearance of the node inputs implicitly matches the values, which may not be ideal. Also the function name may be improved to be more accurate and succinct.

If we generalize this method, what should an “insert_nodes” api do?

Copy link
Contributor Author

@Johansmm Johansmm Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking of cases where new nodes need to be inserted without removing the existing ones. Following the discussion in microsoft/onnxscript#2064, this is not possible with onnxscript.rewriter, since rewrite_pattern has to redefine target_pattern.

Moreover, I believe this helper makes it easier to insert new nodes, with something like:

for node in ir_model.graph:
    if node.op_type == 'expected_op_type':
        insert_node(node.outputs[0], [new_node])

I find this easy to understand and use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding your other comment (about order), do you have any suggestions? When len(values) > 1, I thought it was a good idea to infer the inputs of the list of nodes while preserving their order, which means the nodes in the list have to be reordered if the values need to be reordered.

Maybe it would be enough to specify this more clearly in the function's documentation, right?

) -> None:
"""Inserts a sequence of nodes into the provided value(s).

This allows to insert a list of LINKED nodes (over the same context) at
a specific point in the graph.

For example, suppose we have the following graph::

input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output

We want to insert [node_M, node_N] at B value::

>>> import onnx_ir as ir
>>> input = ir.Input("input")
>>> node_A = ir.node("op_A", [input])
>>> B = ir.Value(name="B")
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B])
>>> node_C = ir.node("op_C", node_B.outputs)
>>> # Create a new sequence to insert
>>> input_2 = ir.Input("input_2")
>>> node_M = ir.node("op_M", [input_2])
>>> node_N = ir.node("op_N", node_M.outputs)
>>> # Insert nodes in B
>>> insert_nodes_in_value(node_B.outputs, [node_M, node_N])
>>> len(node_B.outputs)
1
>>> node_B.outputs[0].consumers()[0].op_type
'op_M'
>>> len(node_C.inputs)
1
>>> node_C.inputs[0].producer().op_type
'op_N'
>>> node_C.inputs[0].name
'B'

When values is a sequence, the set of nodes must have the same number
of inputs and outputs, then they are zipped into pairs: first value is
replaced with the first input/output, and so on.

Args:
values: The value(s) where to insert the nodes.
new_nodes: The nodes to insert in the graph.
"""
if not isinstance(values, Sequence):
values = (values,)

# Search the unique inputs/outputs in new_nodes, keeping the order.
inputs, outputs = _find_inputs_outputs(new_nodes)

# Sanity check.
if len(values) != len(inputs):
raise ValueError(
f"The number of values and inputs ({inputs}) in new_nodes must match."
)
if len(values) != len(outputs):
raise ValueError(
f"The number of values and outputs ({outputs}) in new_nodes must match."
)

# Propagate relevant info.
for val, in_val, out_val in zip(values, inputs, outputs):
# Propagate relevant info from value to out_value.
# TODO(Rama): Perhaps this should be a separate utility function.
out_val.type = val.type
out_val.shape = val.shape
out_val.name = val.name
# Propagate relevant info from value to in_value.
# TODO(Rama): Perhaps this should be a separate utility function.
in_val.type = val.type

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Value | None" has no attribute "type" To disable, use # type: ignore[union-attr]
in_val.shape = val.shape

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Value | None" has no attribute "shape" To disable, use # type: ignore[union-attr]
# Rename each value, following each input.
val.name = in_val.name

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Value | None" has no attribute "name" To disable, use # type: ignore[union-attr]

# Insert the new nodes in two steps:
# 1. Reconnect the users of values to the outputs
replace_all_uses_with(values, outputs)
# 2. Reconnect the users of inputs to values
replace_all_uses_with(inputs, values)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "replace_all_uses_with" has incompatible type "tuple[Value | None, ...]"; expected "ValueProtocol | Sequence[ValueProtocol]" To disable, use # type: ignore[arg-type]

# Update graph if there is one:
if (graph := values[-1].graph) is not None:
# Update graph/function outputs if the node generates output
_update_graph_or_function_outputs(graph, values, outputs)

# Insert new nodes if there is a graph
graph.extend(new_nodes)
graph.sort()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorting after insertion is not ideal. Is there a way to efficiently find the insertion point?

146 changes: 146 additions & 0 deletions src/onnx_ir/_convenience/_init_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the _convenience module."""

import unittest

import onnx

import onnx_ir as ir
from onnx_ir._convenience import insert_nodes_in_value


def _create_model(model_text: str) -> ir.Model:
    """Parse a model written in ONNX textual syntax and deserialize it to an IR model."""
    return ir.serde.deserialize_model(onnx.parser.parse_model(model_text))


class ConvenienceTest(unittest.TestCase):
    """Tests for ``insert_nodes_in_value``."""

    def test_insert_nodes_in_value(self):
        """Insert a two-node chain at a single value that is passed directly."""
        # Main graph
        graph_input = ir.Input("input")
        node_A = ir.node("op_A", [graph_input])
        node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")])
        node_C = ir.node("op_C", node_B.outputs)

        # New sequence to insert
        input_2 = ir.Input("input_2")
        node_M = ir.node("op_M", [input_2])
        node_N = ir.node("op_N", node_M.outputs)

        # Insert nodes in B
        insert_nodes_in_value(node_B.outputs[0], [node_M, node_N])
        self.assertEqual(len(node_B.outputs), 1)
        self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M")
        self.assertEqual(len(node_C.inputs), 1)
        self.assertEqual(node_C.inputs[0].producer().op_type, "op_N")
        self.assertEqual(node_C.inputs[0].name, "B")

    def test_insert_nodes_in_value_in_graph(self):
        """Insert a three-node sequence at both outputs of a node inside a graph."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                a, b = SplitNode(x)
                z = MergeNode(a, b, two)
            }
            """
        )

        # Sequence to insert.
        # Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]].
        i1, i2 = ir.Input("i1"), ir.Input("i2")
        a = ir.node("op_1", [i1, i2])
        b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2)
        c = ir.node("op_3", [i2, b.outputs[0]])

        # Insert nodes in SplitNode.outputs
        target_node = ir_model.graph[1]
        insert_nodes_in_value(target_node.outputs, [a, b, c])

        # Check target_node outputs have been renamed
        new_i1, new_i2 = target_node.outputs
        self.assertEqual(new_i1.name, "i1")
        self.assertEqual(new_i2.name, "i2")

        # Check i1 and i2 have new users
        self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2"))
        self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3"))

        # Check outputs have been correctly renamed as previous values
        self.assertEqual(b.outputs[1].name, "a")
        self.assertEqual(c.outputs[0].name, "b")

        # Check nodes have been inserted in the graph
        self.assertEqual(len(ir_model.graph), 6)

    def test_insert_nodes_in_input(self):
        """Insert a node at a graph input and check the resulting model is valid."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                z = Add(x, two)
            }
            """
        )

        # Sequence to insert.
        x = ir.Input("new_x")
        node = ir.node("Mul", [x, x])

        # Insert nodes in graph.inputs
        insert_nodes_in_value(ir_model.graph[1].inputs[0], [node])
        self.assertEqual(node.outputs[0].name, "x")

        # Check input has been renamed
        self.assertEqual(ir_model.graph.inputs[0].name, "new_x")

        # Finally, check new graph is valid
        proto = ir.to_proto(ir_model)
        onnx.checker.check_model(proto, full_check=True)

    def test_insert_nodes_in_output(self):
        """Insert a node at a graph output and check the output name is kept."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                z = Add(x, two)
            }
            """
        )

        # Sequence to insert.
        x = ir.Input("new_z")
        node = ir.node("Mul", [x, x])

        # Insert nodes in graph.outputs
        insert_nodes_in_value(ir_model.graph.outputs[0], [node])
        self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z")

        # Check output name is preserved
        self.assertEqual(ir_model.graph.outputs[0].name, "z")

    def test_value_error_for_wrong_number_of_points(self):
        """A mismatch between the values and the sequence's inputs/outputs raises."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                a, b = SplitNode(x)
                z = MergeNode(a, b, two)
            }
            """
        )
        # The node has two inputs and a single output.
        node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")])
        # One target value vs. two inputs.
        with self.assertRaisesRegex(ValueError, "The number of values and inputs"):
            insert_nodes_in_value(ir_model.graph[0].outputs, [node])

        # Two target values vs. one output.
        with self.assertRaisesRegex(ValueError, "The number of values and outputs"):
            insert_nodes_in_value(ir_model.graph[1].outputs, [node])


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions src/onnx_ir/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"replace_all_uses_with",
"replace_nodes_and_values",
"create_value_mapping",
"insert_nodes_in_value",
]

from onnx_ir._convenience import (
Expand All @@ -18,6 +19,7 @@
create_value_mapping,
replace_all_uses_with,
replace_nodes_and_values,
insert_nodes_in_value,
)

# NOTE: Do not implement any other functions in this module.
Expand Down
Loading