
Commit 7dd644a

[Autocast] Fix edge case casting input directly to output
Update modelopt/onnx/autocast/precisionconverter.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <[email protected]>

cleanup
Signed-off-by: Ali Boubezari <[email protected]>

Inject identity nodes in sanitizer; revert existing logic; update test
Signed-off-by: Ali Boubezari <[email protected]>

move pass
Signed-off-by: Ali Boubezari <[email protected]>

call sanitizer in precision converter
Signed-off-by: Ali Boubezari <[email protected]>
1 parent cf6f1d4 commit 7dd644a
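
For context, the edge case is a Cast node whose input is a graph input and whose output is a graph output. The fix breaks that direct input-to-output connection by injecting an Identity node behind the Cast, using the tensor naming shown in the graphsanitizer.py diff below:

Before:  X (graph input) --Cast--> Y (graph output)
After:   X (graph input) --Cast--> Y__io_cast_src --Identity--> Y (graph output)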

File tree: 3 files changed, +130 −0 lines changed

- modelopt/onnx/autocast/graphsanitizer.py (+38)
- modelopt/onnx/autocast/precisionconverter.py (+19)
- tests/unit/onnx/autocast/test_precisionconverter.py (+73)

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 38 additions & 0 deletions
@@ -63,6 +63,7 @@ def sanitize(self) -> None:
         self.ensure_graph_name_exists()
         onnx_utils.name_onnx_nodes(self.model.graph)
         self.replace_custom_domain_nodes()
+        self.sanitize_io_casts()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)

@@ -322,6 +323,43 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
             logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
             return None

+    def sanitize_io_casts(self) -> None:
+        """Handle the special case where an input is cast directly to an output.
+
+        Inject an Identity node after the Cast node.
+        """
+        model_input_names = {input.name for input in self.model.graph.input}
+        model_output_names = {output.name for output in self.model.graph.output}
+        nodes_to_add = []
+        for node in self.model.graph.node:
+            if (
+                node.op_type == "Cast"
+                and node.input
+                and node.output
+                and node.input[0] in model_input_names
+                and node.output[0] in model_output_names
+            ):
+                # Unique per graph output to avoid collisions when multiple outputs
+                # are cast from the same input
+                cast_output_name = node.output[0]
+                cast_new_output_name = f"{cast_output_name}__io_cast_src"
+                nodes_to_add.append(
+                    helper.make_node(
+                        "Identity",
+                        inputs=[cast_new_output_name],
+                        outputs=[cast_output_name],
+                        name=f"{node.name}__io_cast_identity",
+                    )
+                )
+                # Rewire Cast to produce the new intermediate
+                node.output[0] = cast_new_output_name
+
+        for node in nodes_to_add:
+            self.model.graph.node.append(node)
+
+        # Make sure the graph is topologically sorted
+        gs_graph = gs.import_onnx(self.model).cleanup().toposort()
+        self.model = gs.export_onnx(gs_graph)
+
     def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
         """Create a LayerNormalization node with optional bias."""
         ln_name = f"LayerNorm_{pattern['mean_node'].name}"

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 19 additions & 0 deletions
@@ -32,6 +32,7 @@

 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
+from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
 from modelopt.onnx.autocast.logging_config import configure_logging, logger

 configure_logging()

@@ -73,6 +74,9 @@ def __init__(
         low_precision_type: str = "fp16",
         init_conversion_max_bytes: int | None = None,
         custom_ops: set[str] | None = None,
+        min_opset: int = 13,
+        max_ir_version: int | None = None,
+        trt_plugins: list[str] | None = [],
     ) -> None:
         """Initialize PrecisionConverter.

@@ -109,6 +113,9 @@ def __init__(
         self.original_network_io.update(
             {io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
         )
+        self.min_opset = min_opset
+        self.max_ir_version = max_ir_version
+        self.trt_plugins = trt_plugins

     def convert(
         self,

@@ -132,6 +139,8 @@ def convert(
                 "AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
             )

+        self._sanitize_model()
+
         # Filter out nodes that are not allowed to be in low precision
         # This is done here and not in NodeClassifier because it is required for the model to be valid
         high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(

@@ -1030,3 +1039,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
             get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
             return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
         return False
+
+    def _sanitize_model(self):
+        graph_sanitizer = GraphSanitizer(
+            self.model,
+            self.min_opset,
+            trt_plugins=self.trt_plugins,
+            max_ir_version=self.max_ir_version,
+        )
+        graph_sanitizer.sanitize()
+        self.model = graph_sanitizer.model
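
The three new constructor arguments exist so convert() can build the sanitizer itself: _sanitize_model() forwards them to GraphSanitizer before any precision filtering runs. A hedged construction sketch, given an onnx.ModelProto named model (mapping helpers and values mirror the test below; none of this is a documented API promise):

import modelopt.onnx.autocast.utils as utils
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter

value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
converter = PrecisionConverter(
    model,
    value_info_map,
    initializer_map,
    node_to_init_map,
    keep_io_types=True,
    low_precision_type="fp16",
    min_opset=13,       # forwarded to GraphSanitizer
    max_ir_version=10,  # forwarded to GraphSanitizer; 10 is the ORT limit per the test constant
    trt_plugins=[],     # forwarded to GraphSanitizer
)
converted_model = converter.convert(high_precision_nodes=[], low_precision_nodes=[])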

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 73 additions & 0 deletions
@@ -30,6 +30,9 @@ def low_precision_onnx_type(low_precision_type_str):
     return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16


+LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
+
+
 ####################################################################################################
 # Testing with a basic GEMM->Add->Relu graph
 ####################################################################################################

@@ -1023,3 +1026,73 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_casted_input_to_output():
+    """Create a model where a graph input is cast directly to a graph output."""
+    # Create input and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output fed directly by the Cast
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output of the Add chain
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")
+
+    # Create the cast node that feeds directly from input to output
+    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast_input],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_input_to_output_model(
+    model_with_casted_input_to_output, low_precision_type, keep_io_types
+):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+        min_opset=22 if low_precision_type == "bf16" else 13,
+        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
+        trt_plugins=[],
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
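
Assuming a standard pytest setup, the new case can be exercised in isolation with something like: pytest tests/unit/onnx/autocast/test_precisionconverter.py -k casted_input_to_output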
