From 8ac954fea0644fc6a215fe1383a1ece57a3f7981 Mon Sep 17 00:00:00 2001 From: Johansmm Date: Wed, 4 Jun 2025 00:23:32 +0200 Subject: [PATCH 1/3] Refactor: update_graph_outputs in a helper (#62) Signed-off-by: Johansmm --- src/onnx_ir/_convenience/__init__.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 45f740b..589009e 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: return values +def _update_graph_or_function_outputs( + graph_or_function: _core.Graph | _core.Function, + old_values: Sequence[_core.Value], + new_values: Sequence[_core.Value], +): + """Update graph/function outputs.""" + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(graph_or_function.outputs): + if graph_or_function_output in replacement_mapping: + graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + + def replace_nodes_and_values( graph_or_function: _core.Graph | _core.Function, /, @@ -367,10 +379,7 @@ def replace_nodes_and_values( # Reconnect the users of the deleted values to use the new values replace_all_uses_with(old_values, new_values) # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + _update_graph_or_function_outputs(graph_or_function, old_values, new_values) # insert new nodes after the index node graph_or_function.insert_after(insertion_point, new_nodes) From eb43a2f400dde7b35f5601299003757c2852d986 Mon Sep 17 00:00:00 2001 From: Johansmm Date: Wed, 4 Jun 2025 00:28:29 
+0200 Subject: [PATCH 2/3] Introduce insert_nodes_in_value (#62) Convenience function to insert a set of nodes in value(s). Signed-off-by: Johansmm --- src/onnx_ir/_convenience/__init__.py | 106 ++++++++++++++++++ src/onnx_ir/_convenience/_init_test.py | 146 +++++++++++++++++++++++++ src/onnx_ir/convenience.py | 2 + 3 files changed, 254 insertions(+) create mode 100644 src/onnx_ir/_convenience/_init_test.py diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 589009e..bcbc7bd 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -14,6 +14,7 @@ "replace_all_uses_with", "create_value_mapping", "replace_nodes_and_values", + "insert_nodes_in_value", ] from collections.abc import Mapping, Sequence @@ -384,3 +385,108 @@ def replace_nodes_and_values( # insert new nodes after the index node graph_or_function.insert_after(insertion_point, new_nodes) graph_or_function.remove(old_nodes, safe=True) + + +def _find_inputs_outputs( + nodes: Sequence[_core.Node], +) -> tuple[tuple[_core.Value | None, ...], tuple[_core.Value, ...]]: + """Find the values that are considered as inputs and outputs in a sequence of nodes.""" + # Search the unique inputs/outputs in new_nodes, keeping the order. + all_inputs = dict.fromkeys(sum((node.inputs for node in nodes), ())) # type: ignore[type-var] + all_outputs = dict.fromkeys(sum((node.outputs for node in nodes), ())) # type: ignore[type-var] + # A value is considered as input if it is not any output. + inputs = tuple(val for val in all_inputs if val not in all_outputs) + # A value is considered as output if it is not any input. + outputs = tuple(val for val in all_outputs if val not in all_inputs) + return inputs, outputs + + +def insert_nodes_in_value( + values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node] +) -> None: + """Inserts a sequence of nodes into the provided value(s). 
+ + This allows inserting a list of LINKED nodes (over the same context) at + a specific point in the graph. + + For example, suppose we have the following graph:: + + input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output + + We want to insert [node_M, node_N] at value B:: + + >>> import onnx_ir as ir + >>> input = ir.Input("input") + >>> node_A = ir.node("op_A", [input]) + >>> B = ir.Value(name="B") + >>> node_B = ir.node("op_B", node_A.outputs, outputs=[B]) + >>> node_C = ir.node("op_C", node_B.outputs) + >>> # Create a new sequence to insert + >>> input_2 = ir.Input("input_2") + >>> node_M = ir.node("op_M", [input_2]) + >>> node_N = ir.node("op_N", node_M.outputs) + >>> # Insert nodes in B + >>> insert_nodes_in_value(node_B.outputs, [node_M, node_N]) + >>> len(node_B.outputs) + 1 + >>> node_B.outputs[0].consumers()[0].op_type + 'op_M' + >>> len(node_C.inputs) + 1 + >>> node_C.inputs[0].producer().op_type + 'op_N' + >>> node_C.inputs[0].name + 'B' + + When values is a sequence, the set of nodes must have as many inputs + and as many outputs as there are values; they are zipped into pairs: the + first value is replaced with the first input/output, and so on. + + Args: + values: The value(s) where to insert the nodes. + new_nodes: The nodes to insert in the graph. + """ + if not isinstance(values, Sequence): + values = (values,) + + # Search the unique inputs/outputs in new_nodes, keeping the order. + inputs, outputs = _find_inputs_outputs(new_nodes) + + # Sanity check. + if len(values) != len(inputs): + raise ValueError( + f"The number of values and inputs ({inputs}) in new_nodes must match." + ) + if len(values) != len(outputs): + raise ValueError( + f"The number of values and outputs ({outputs}) in new_nodes must match." + ) + + # Propagate relevant info. + for val, in_val, out_val in zip(values, inputs, outputs): + # Propagate relevant info from value to out_value. + # TODO(Rama): Perhaps this should be a separate utility function. 
+ out_val.type = val.type + out_val.shape = val.shape + out_val.name = val.name + # Propagate relevant info from value to in_value. + # TODO(Rama): Perhaps this should be a separate utility function. + in_val.type = val.type + in_val.shape = val.shape + # Rename each value, following each input. + val.name = in_val.name + + # Insert the new nodes in two steps: + # 1. Reconnect the users of values to the outputs + replace_all_uses_with(values, outputs) + # 2. Reconnect the users of inputs to values + replace_all_uses_with(inputs, values) + + # Update graph if there is one: + if (graph := values[-1].graph) is not None: + # Update graph/function outputs if the node generates output + _update_graph_or_function_outputs(graph, values, outputs) + + # Insert new nodes if there is a graph + graph.extend(new_nodes) + graph.sort() diff --git a/src/onnx_ir/_convenience/_init_test.py b/src/onnx_ir/_convenience/_init_test.py new file mode 100644 index 0000000..b86162a --- /dev/null +++ b/src/onnx_ir/_convenience/_init_test.py @@ -0,0 +1,146 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the _convenience module.""" + +import unittest + +import onnx + +import onnx_ir as ir +from onnx_ir._convenience import insert_nodes_in_value + + +def _create_model(model_text: str) -> ir.Model: + model = onnx.parser.parse_model(model_text) + return ir.serde.deserialize_model(model) + + +class ConvenienceTest(unittest.TestCase): + def test_insert_nodes_in_value(self): + # Main graph + input = ir.Input("input") + node_A = ir.node("op_A", [input]) + node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")]) + node_C = ir.node("op_C", node_B.outputs) + + # New sequence to insert + input_2 = ir.Input("input_2") + node_M = ir.node("op_M", [input_2]) + node_N = ir.node("op_N", node_M.outputs) + + # Insert nodes in B + insert_nodes_in_value(node_B.outputs[0], [node_M, node_N]) + self.assertEqual(len(node_B.outputs), 1) + 
self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M") + self.assertEqual(len(node_C.inputs), 1) + self.assertEqual(node_C.inputs[0].producer().op_type, "op_N") + self.assertEqual(node_C.inputs[0].name, "B") + + def test_insert_nodes_in_value_in_graph(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + + # Sequence to insert. + # Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]]. + i1, i2 = ir.Input("i1"), ir.Input("i2") + a = ir.node("op_1", [i1, i2]) + b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2) + c = ir.node("op_3", [i2, b.outputs[0]]) + + # Insert nodes in SplitNode.outputs + target_node = ir_model.graph[1] + insert_nodes_in_value(target_node.outputs, [a, b, c]) + + # Check target_node outputs have been renamed + new_i1, new_i2 = target_node.outputs + self.assertEqual(new_i1.name, "i1") + self.assertEqual(new_i2.name, "i2") + + # Check i1 and i2 have new users + self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")) + self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")) + + # Check outputs have been correctly renamed as previous values + self.assertEqual(b.outputs[1].name, "a") + self.assertEqual(c.outputs[0].name, "b") + + # Check nodes have been inserted in the graph + self.assertEqual(len(ir_model.graph), 6) + + def test_insert_nodes_in_input(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. 
+ x = ir.Input("new_x") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.inputs + insert_nodes_in_value(ir_model.graph[1].inputs[0], [node]) + self.assertEqual(node.outputs[0].name, "x") + + # Check input has been renamed + self.assertEqual(ir_model.graph.inputs[0].name, "new_x") + + # Finally, check new graph is valid + proto = ir.to_proto(ir_model) + onnx.checker.check_model(proto, full_check=True) + + def test_insert_nodes_in_output(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. + x = ir.Input("new_z") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.outputs + insert_nodes_in_value(ir_model.graph.outputs[0], [node]) + self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z") + + # Check output name is preserved + self.assertEqual(ir_model.graph.outputs[0].name, "z") + + def test_value_error_for_wrong_number_of_points(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")]) + with self.assertRaisesRegex(ValueError, "The number of values and inputs"): + insert_nodes_in_value(ir_model.graph[0].outputs, [node]) + + with self.assertRaisesRegex(ValueError, "The number of values and outputs"): + insert_nodes_in_value(ir_model.graph[1].outputs, [node]) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/onnx_ir/convenience.py b/src/onnx_ir/convenience.py index 2d6bffc..6cd8b04 100644 --- a/src/onnx_ir/convenience.py +++ b/src/onnx_ir/convenience.py @@ -10,6 +10,7 @@ "replace_all_uses_with", "replace_nodes_and_values", "create_value_mapping", + "insert_nodes_in_value", ] from onnx_ir._convenience import ( @@ -18,6 +19,7 @@ create_value_mapping, replace_all_uses_with, replace_nodes_and_values, + insert_nodes_in_value, ) # NOTE: Do not implement any 
other functions in this module. From f077782b1ad45968d49171014f68c8bfce06187b Mon Sep 17 00:00:00 2001 From: Johansmm Date: Wed, 4 Jun 2025 00:32:19 +0200 Subject: [PATCH 3/3] Include insert_nodes_in_value in doc (#62) Signed-off-by: Johansmm --- docs/api/ir_convenience.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/ir_convenience.md b/docs/api/ir_convenience.md index 4cdfdec..95203a0 100644 --- a/docs/api/ir_convenience.md +++ b/docs/api/ir_convenience.md @@ -12,4 +12,5 @@ .. autofunction:: replace_all_uses_with .. autofunction:: replace_nodes_and_values .. autofunction:: create_value_mapping +.. autofunction:: insert_nodes_in_value ```