[Autocast] Fix edge case casting input directly to output #305

Conversation
Walkthrough

Adds GraphSanitizer.sanitize_io_casts and invokes it from sanitize() to isolate Casts that wire graph inputs directly to outputs by inserting Identity nodes; introduces duplicate sanitize_io_casts definitions. Adds a fixture and a parameterized test for casted-input models. Extends PrecisionConverter to accept min_opset, max_ir_version, and trt_plugins and to run a sanitize pass before conversion.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as Unit Test
    participant PC as PrecisionConverter
    participant GS as GraphSanitizer
    participant G as ONNX Graph
    Test->>PC: construct PrecisionConverter(..., min_opset, max_ir_version, trt_plugins)
    PC->>PC: store params
    PC->>GS: _sanitize_model(self.model)
    GS->>G: inspect nodes/topology
    alt Cast node input=graph_input and output=graph_output
        GS->>G: create Identity node, rewire Cast output -> Identity -> original output
        GS->>G: re-topologize graph
        GS-->>PC: return sanitized model
    else
        GS-->>PC: return model unchanged
    end
    PC->>PC: run suitability checks and conversion
    PC-->>Test: return converted model
```
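To make the edge case concrete, here is a minimal sketch of a model where a Cast wires a graph input directly to a graph output (illustrative only; the test fixture in this PR builds a similar graph with an additional compute path, and the tensor/node names `X`, `Y1`, and `cast_input` mirror that fixture):

```python
import onnx
from onnx import TensorProto, helper

# Minimal model: input X is cast directly to output Y1 with no intermediate node.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])

# The problematic pattern: a Cast consumes a graph input and produces a graph output.
cast = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)

graph = helper.make_graph([cast], "io_cast_edge_case", [x], [y1])
model = helper.make_model(graph)
onnx.checker.check_model(model)
```

Without the fix, removing or renaming this Cast during precision conversion leaves Y1 with no producer, disconnecting the output.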
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
624-634: Preexisting-cast removal misses BF16 targets

Comment says FP16/BF16/FP32 casts are removed, but `is_fp_cast` only matches `to ∈ {FLOAT16, FLOAT}`. Include BF16 as a target to honor the contract.

```diff
-        is_fp_cast = cast_to_type in [
-            onnx.TensorProto.FLOAT16,
-            onnx.TensorProto.FLOAT,
-        ] and cast_from_type in [
+        is_fp_cast = cast_to_type in [
+            onnx.TensorProto.FLOAT16,
+            onnx.TensorProto.FLOAT,
+            onnx.TensorProto.BFLOAT16,
+        ] and cast_from_type in [
             onnx.TensorProto.FLOAT16,
             onnx.TensorProto.FLOAT,
             onnx.TensorProto.BFLOAT16,
         ]
```
641-644: Guard for output-producing casts is ineffective

This condition checks whether BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

```diff
-        # Keep cast nodes that are necessary producers of network outputs
-        if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-            node.output[0] == out.name for out in self.model.graph.output
-        ):
+        # Keep casts that produce a network output
+        if node.output[0] in model_output_names:
             continue
```
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)
618-621: Insert duplicate IO-bridge casts deterministically (top of graph)

Appending at the tail can shuffle topological order. Inserting at index 0 is more stable for input-driven casts.

```diff
-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        for cast in casts_to_add:
+            self.model.graph.node.insert(0, cast)
```

tests/unit/onnx/autocast/test_precisionconverter.py (2)
1068-1071: Don't write artifacts to /tmp in unit tests

The saved model isn't used. Remove it to keep tests hermetic.

```diff
-    onnx.save(model, "/tmp/model_with_casted_output.onnx")
```
1076-1091: Strengthen assertions: verify Y1 connectivity and dtype

Only checking `onnx.checker` is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

```diff
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
-    onnx.checker.check_model(converted_model)
+    onnx.checker.check_model(converted_model)
+    # Y1 should remain connected and produced by a Cast
+    y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+    assert len(y1_producers) == 1
+    assert y1_producers[0].op_type == "Cast"
+    # keep_io_types=True -> FP32 I/O preserved
+    y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+    assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/onnx/autocast/precisionconverter.py (1 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
607-617: Confirm intent: skipping only the duplicated "new_cast", not the renamed original

`casts_to_skip` holds the original name (now assigned to `new_cast`) and will not skip the renamed original (`..._io_special_case`). If that original gets removed, `_bypass_cast_node` will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list as well right after renaming.

```diff
-            casts_to_skip.append(node.name)
+            casts_to_skip.append(node.name)
             casts_to_add.append(new_cast)

             # Now adjust the old cast's name, consumers and producers
             node.name = f"{node.name}_io_special_case"
+            casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well
```
Actionable comments posted: 0
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)
600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift

If the model uses newer opsets where Cast may carry optional attributes (like saturate), the duplicate should copy them.

Apply:

```diff
-            new_cast = helper.make_node(
-                "Cast",
-                name=node.name,
-                inputs=[node.input[0]],
-                outputs=[node.output[0]],
-                to=utils.get_cast_to_type(node),
-            )
+            # Copy optional attributes (e.g., 'saturate' in newer opsets)
+            saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+            cast_attrs = {"to": utils.get_cast_to_type(node)}
+            if saturate is not None:
+                cast_attrs["saturate"] = saturate
+            new_cast = helper.make_node(
+                "Cast",
+                name=node.name,
+                inputs=[node.input[0]],
+                outputs=[node.output[0]],
+                **cast_attrs,
+            )
```
618-621: Insert duplicate Cast adjacent to the original for better locality and readability

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep the topology readable.

```diff
-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        # Preserve locality: insert duplicates next to their originals
+        for cast in casts_to_add:
+            target_idx = -1
+            for i, n in enumerate(self.model.graph.node):
+                if n.name == f"{cast.name}_io_special_case":
+                    target_idx = i
+                    break
+            if target_idx >= 0:
+                self.model.graph.node.insert(target_idx, cast)
+            else:
+                # Fallback to prepend to avoid end-append reordering
+                self.model.graph.node.insert(0, cast)
```
592-596: Use a set for casts_to_skip from the start

Minor nit for clarity and O(1) membership checks.

```diff
-        casts_to_skip = []
+        casts_to_skip: set[str] = set()
         # Add casts as a separate step to avoid modifying the graph while iterating over it
         casts_to_add = []
@@
-            casts_to_skip.append(node.name)
+            casts_to_skip.add(node.name)
@@
-        casts_to_skip = set(casts_to_skip)
+        # already a set
```

Also applies to: 620-621
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
588-617: Solid IO-cast preservation strategy

Duplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.
```python
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=True,
```
This test fails when `keep_io_types=False` due to a graph input becoming a graph output directly, which violates the assertion in ModelOpt that all original input and output names should be maintained in the quantized model.
This is indeed an edge case that challenges AutoCast's assumptions.
I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node.
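A rough sketch of that idea (illustrative only; the actual sanitize_io_casts implementation in this PR may differ in naming and details, and the `_pre_identity` / `_identity` naming scheme here is hypothetical):

```python
from onnx import helper

def sanitize_io_casts(graph):
    """Sketch: isolate Casts that wire a graph input directly to a graph output."""
    input_names = {i.name for i in graph.input}
    output_names = {o.name for o in graph.output}
    new_nodes = []
    for node in graph.node:
        if (
            node.op_type == "Cast"
            and node.input[0] in input_names
            and node.output[0] in output_names
        ):
            # Redirect the Cast to an intermediate tensor and let an Identity
            # produce the original graph output, so the Cast no longer bridges
            # input and output directly.
            inter = f"{node.output[0]}_pre_identity"  # hypothetical naming scheme
            identity = helper.make_node(
                "Identity", [inter], [node.output[0]], name=f"{node.name}_identity"
            )
            node.output[0] = inter
            new_nodes.append(identity)
    graph.node.extend(new_nodes)
```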
@aboubezari @gcunhase
@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!
Will give this a try. Thanks.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
1078-1097: Recompute mappings after sanitize() and assert outputs preserved

sanitize() mutates the graph (adds Identity and rewires). Recompute mappings to avoid stale references, and assert that both outputs remain present.

Apply this diff:

```diff
 def test_casted_input_to_output_model(
     model_with_casted_input_to_output, low_precision_type, keep_io_types
 ):
     model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output

     min_opset = 22 if low_precision_type == "bf16" else 13
     graph_sanitizer = GraphSanitizer(model, min_opset)
     graph_sanitizer.sanitize()
-    model = graph_sanitizer.model
+    model = graph_sanitizer.model
+    # Recompute mappings after graph mutation by sanitizer
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

     converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+    # Ensure both original outputs are preserved
+    assert {o.name for o in converted_model.graph.output} == {"Y1", "Y2"}
```

Optionally, to validate type expectations on the direct-IO-cast output:
- If keep_io_types is True: Y1/Y2 should be FLOAT.
- If False: Y1/Y2 should be low_precision_onnx_type(low_precision_type).
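For example (a minimal sketch, assuming a `low_precision_onnx_type` helper that maps "fp16"/"bf16" to the corresponding TensorProto enum is available in the test module):

```python
# Sketch only: low_precision_onnx_type is an assumed helper, not verified here.
expected = (
    TensorProto.FLOAT if keep_io_types else low_precision_onnx_type(low_precision_type)
)
for out in converted_model.graph.output:
    if out.name in ("Y1", "Y2"):
        assert out.type.tensor_type.elem_type == expected
```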
1028-1067: Optional: Expand fixture to cover multi-Cast-from-same-input edge case

To guard against future regressions (multiple Casts from the same input to different outputs), add a companion fixture/test. Example snippet to append to this file:

```python
@pytest.fixture
def model_with_two_casted_outputs():
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])

    cast1 = helper.make_node("Cast", ["X"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["X"], ["Y2"], name="cast2", to=TensorProto.FLOAT)

    graph = helper.make_graph([cast1, cast2], "model_two_casted_outputs", [x], [y1, y2], [])
    model = helper.make_model(graph, producer_name="model_two_casted_outputs")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)
    model = onnx_utils.infer_shapes(model)

    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_two_casts_from_same_input(model_with_two_casted_outputs, keep_io_types, low_precision_type):
    model, value_info_map, initializer_map, node_to_init_map = model_with_two_casted_outputs
    gsani = GraphSanitizer(model, 22 if low_precision_type == "bf16" else 13)
    gsani.sanitize()
    model = gsani.model
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    conv = PrecisionConverter(
        model, value_info_map, initializer_map, node_to_init_map,
        keep_io_types=keep_io_types, low_precision_type=low_precision_type,
    )
    converted = conv.convert(high_precision_nodes=["cast1", "cast2"], low_precision_nodes=[])
    onnx.checker.check_model(converted)
    assert {o.name for o in converted.graph.output} == {"Y1", "Y2"}
```

I can send a PR update with these additions if helpful.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  GraphSanitizer (28-450)
  sanitize (53-68)
🔇 Additional comments (2)
modelopt/onnx/autocast/graphsanitizer.py (1)
68-69: Good placement in sanitize()

Running sanitize_io_casts() after cleanup and IR versioning keeps the injected Identity from being swept early. LGTM.
tests/unit/onnx/autocast/test_precisionconverter.py (1)
23-24: Import addition looks good

GraphSanitizer is used below; the import is correct and scoped for tests.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
1077-1098: Consider adding explicit assertions for the edge case.

The test validates that the model passes `check_model` after conversion, which confirms basic correctness. However, it could be strengthened by explicitly asserting that:
- The Identity node was inserted by the sanitizer
- The original input and output names are preserved
- The Cast node is still present
This would make the test's intent clearer and provide better regression coverage for the specific edge case being addressed.
Example assertions to add before line 1098:
```python
# Verify Identity node was inserted for the IO cast
identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
assert len(identity_nodes) >= 1, "Expected at least one Identity node for IO cast isolation"

# Verify original I/O names are preserved
assert converted_model.graph.input[0].name == "X"
assert any(o.name == "Y1" for o in converted_model.graph.output)
assert any(o.name == "Y2" for o in converted_model.graph.output)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/onnx/autocast/precisionconverter.py (5 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  convert (120-211)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  GraphSanitizer (28-460)
  sanitize (53-68)
🔇 Additional comments (6)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
33-33: LGTM! Constant definition is clear.

The constant `LATEST_IR_VERSION_SUPPORTED_BY_ORT` is well-named and appropriately scoped for test usage.
1031-1074: LGTM! Fixture correctly models the edge case.

The fixture constructs a graph where a Cast node connects input X directly to output Y1, alongside a separate computation path. This accurately captures the edge case described in the PR objectives.
modelopt/onnx/autocast/precisionconverter.py (4)
35-35: LGTM! Import is correctly placed.

The GraphSanitizer import is appropriately added to support the new sanitization step.
77-79: LGTM! New parameters are well-integrated.

The three new parameters (`min_opset`, `max_ir_version`, `trt_plugins`) are properly added to the constructor with sensible defaults and stored as instance variables for use in the sanitization step.

Also applies to: 116-118
142-142: LGTM! Sanitization is correctly positioned in the workflow.

The `_sanitize_model()` call is appropriately placed early in the `convert()` method, after model validation but before unsupported op filtering. This ensures the graph is normalized (e.g., IO casts are isolated with Identity nodes) before precision conversion logic executes.
1043-1051: LGTM! Sanitization method is clean and correct.

The `_sanitize_model()` method correctly instantiates `GraphSanitizer` with the model and relevant parameters, invokes sanitization, and updates the model. This cleanly integrates the graph normalization step (including IO cast isolation with Identity nodes) into the precision conversion workflow.
0e0d11a to 7dd644a (compare)

Commits:
- Update modelopt/onnx/autocast/precisionconverter.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>; Signed-off-by: aboubezari <[email protected]>)
- cleanup (Signed-off-by: Ali Boubezari <[email protected]>)
- Inject identity nodes in sanitizer; revert existing logic; update test (Signed-off-by: Ali Boubezari <[email protected]>)
- Inject identity nodes in sanitizer; revert existing logic; update test (Signed-off-by: Ali Boubezari <[email protected]>)
- move pass (Signed-off-by: Ali Boubezari <[email protected]>)
- call sanitizer in precision converter (Signed-off-by: Ali Boubezari <[email protected]>)

7dd644a to 4028e2a (compare)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
77-92: Document the new parameters in the docstring.

The three new parameters (`min_opset`, `max_ir_version`, `trt_plugins`) are not documented in the docstring.

Apply this diff:

```diff
         Args:
             model: ONNX model to convert.
             value_info_map: Map of tensor names to value info.
             initializer_map: Map of tensor names to initializers.
             node_to_init_map: Map of node names to lists of initializer names.
             keep_io_types: Keep the input and output types of the model, otherwise they will be converted.
             low_precision_type: Precision to convert to.
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion.
                 Larger initializers will be cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset version to use for sanitization (default: 13).
+            max_ir_version: Maximum IR version supported by ORT (default: None).
+            trt_plugins: List of TensorRT plugin library paths in .so format (default: []).
```
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
1063-1071: Consider matching parameter order to the GraphSanitizer constructor signature.

While using keyword arguments makes order irrelevant, matching the constructor's parameter order improves readability.

Apply this diff:

```diff
     def _sanitize_model(self):
         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
             max_ir_version=self.max_ir_version,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
modelopt/onnx/autocast/precisionconverter.py (5 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unit/onnx/autocast/test_precisionconverter.py
- modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  GraphSanitizer (28-560)
  sanitize (53-69)
🔇 Additional comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
35-35: LGTM!

The import is correctly placed and necessary for the sanitization functionality.
142-143: LGTM!

The sanitization is correctly invoked early in the conversion workflow, after validation but before precision conversion operations.
```python
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = [],
```
Replace mutable default argument with None.
The default value `trt_plugins=[]` creates a mutable default argument, which can lead to shared state between instances if the list is mutated.
Apply this diff:
```diff
-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,
```
Then update line 1067 to handle the None case:
```diff
         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
             max_ir_version=self.max_ir_version,
         )
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    def __init__(
        ...,
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = None,
    ):
        ...
```

```python
        graph_sanitizer = GraphSanitizer(
            self.model,
            self.min_opset,
            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
            max_ir_version=self.max_ir_version,
        )
```
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 77 to 79, change
the function signature to use trt_plugins: list[str] | None = None instead of a
mutable default list, and then at line 1067 update the code to treat a None
value as an empty list (e.g., set local_trt_plugins = trt_plugins or [] before
using it) so any subsequent iterations or mutations operate on a fresh list
rather than a shared default.
What does this PR do?
Type of change: Bug fix
Overview: If there is a cast node connecting an input directly to an output, the output will become completely disconnected due to naming issues. This fix creates specialized cast nodes for such edge cases and avoids removing them in the initial pass.
Usage
Autocast precision converter
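A minimal usage sketch of the converter with the new sanitize pass (illustrative; parameter and node names follow the unit test in this PR, and the mappings are assumed to come from `utils.setup_mappings`):

```python
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter

# value_info_map, initializer_map, and node_to_init_map are assumed to come
# from the project's mapping setup utilities, e.g. utils.setup_mappings(model).
converter = PrecisionConverter(
    model,
    value_info_map,
    initializer_map,
    node_to_init_map,
    keep_io_types=True,
    low_precision_type="fp16",
    min_opset=13,
    max_ir_version=10,
)
# convert() now runs a GraphSanitizer pass first, which isolates
# input-to-output Casts with Identity nodes before precision conversion.
converted = converter.convert(
    high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
)
```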
Testing
Added a unit test that fails before my change and passes after my fix.
Before your PR is "Ready for review"
Summary by CodeRabbit
Bug Fixes
New Features
Tests