
Conversation


@aboubezari aboubezari commented Sep 9, 2025

What does this PR do?

Type of change: Bug fix

Overview: If there is a Cast node connecting a graph input directly to a graph output, the output becomes completely disconnected due to naming issues. This fix creates specialized Cast nodes for such edge cases and avoids removing them in the initial pass.

Usage

Autocast precision converter

Testing

Added a unittest that fails before my change, and passes after my fix.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Summary by CodeRabbit

  • Bug Fixes

    • Improved graph sanitization to detect casts from model inputs directly to outputs and insert intermediates, preserving valid I/O and reducing conversion errors for FP16/BF16 workflows.
  • New Features

    • Precision conversion now runs a sanitization pass before conversion and exposes configuration options for opset, IR version, and plugin handling.
  • Tests

    • Added unit tests and a fixture covering the casted input→output scenario, parameterized for low‑precision types and I/O preservation.

@aboubezari aboubezari requested a review from a team as a code owner September 9, 2025 00:47
@aboubezari aboubezari requested a review from ajrasane September 9, 2025 00:47

copy-pr-bot bot commented Sep 9, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds GraphSanitizer.sanitize_io_casts and invokes it from sanitize() to isolate Casts that wire graph inputs directly to outputs by inserting Identity nodes; introduces duplicate sanitize_io_casts definitions. Adds a fixture and parameterized test for casted-input models. Extends PrecisionConverter to accept min_opset, max_ir_version, trt_plugins and run a sanitize pass before conversion.

Changes

Cohort / File(s) Summary
Graph sanitization
modelopt/onnx/autocast/graphsanitizer.py
Adds sanitize_io_casts() to detect Cast nodes whose input is a graph input and output is a graph output, insert an intermediate Identity node, rewire outputs, and re-topologize. Invokes sanitize_io_casts() from sanitize(). The patch contains duplicate sanitize_io_casts definitions within the class.
Precision conversion
modelopt/onnx/autocast/precisionconverter.py
Imports GraphSanitizer and runs a sanitization pass before conversion. Constructor signature extended to accept min_opset: int = 13, max_ir_version: int | None = None, and trt_plugins: list[str] | None = [].
ONNX autocast tests
tests/unit/onnx/autocast/test_precisionconverter.py
Adds constant LATEST_IR_VERSION_SUPPORTED_BY_ORT. Adds fixture model_with_casted_input_to_output() returning (model, value_info_map, initializer_map, node_to_init_map). Adds parametrized test test_casted_input_to_output_model(...) covering low_precision_type ("fp16"/"bf16") and keep_io_types (True/False) that constructs PrecisionConverter with max_ir_version and runs conversion/asserts validity.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Test as Unit Test
  participant PC as PrecisionConverter
  participant GS as GraphSanitizer
  participant G as ONNX Graph

  Test->>PC: construct PrecisionConverter(..., min_opset, max_ir_version, trt_plugins)
  PC->>PC: store params
  PC->>GS: _sanitize_model(self.model)
  GS->>G: inspect nodes/topology
  alt Cast node input=graph_input and output=graph_output
    GS->>G: create Identity node, rewire Cast output -> Identity -> original output
    GS->>G: re-topologize graph
    GS-->>PC: return sanitized model
  else
    GS-->>PC: return model unchanged
  end
  PC->>PC: run suitability checks and conversion
  PC-->>Test: return converted model

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibbled wires where Casts did meet,
I slipped an Identity in to keep things neat.
fp16 and bf16 twirled with glee,
Sanitized hops set outputs free.
— a rabbit, cheering on the CI 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 66.67%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title "[Autocast] Fix edge case casting input directly to output" clearly identifies the affected component (Autocast) and succinctly summarizes the primary change, addressing the special-case Cast node between an input and an output, so it accurately reflects the pull request's main objective without extraneous detail.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

624-634: Preexisting-cast removal misses BF16 targets

Comment says FP16/BF16/FP32 casts are removed, but is_fp_cast only matches to ∈ {FLOAT16, FLOAT}. Include BF16 as a target to honor the contract.

-                is_fp_cast = cast_to_type in [
-                    onnx.TensorProto.FLOAT16,
-                    onnx.TensorProto.FLOAT,
-                ] and cast_from_type in [
+                is_fp_cast = cast_to_type in [
+                    onnx.TensorProto.FLOAT16,
+                    onnx.TensorProto.FLOAT,
+                    onnx.TensorProto.BFLOAT16,
+                ] and cast_from_type in [
                     onnx.TensorProto.FLOAT16,
                     onnx.TensorProto.FLOAT,
                     onnx.TensorProto.BFLOAT16,
                 ]

641-644: Guard for output-producing casts is ineffective

This condition checks if BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

-                    # Keep cast nodes that are necessary producers of network outputs
-                    if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-                        node.output[0] == out.name for out in self.model.graph.output
-                    ):
+                    # Keep casts that produce a network output
+                    if node.output[0] in model_output_names:
                         continue
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)

618-621: Insert duplicate IO-bridge casts deterministically (top of graph)

Appending at the tail can shuffle topo order. Inserting at index 0 is more stable for input-driven casts.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        for cast in casts_to_add:
+            self.model.graph.node.insert(0, cast)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

1068-1071: Don’t write artifacts to /tmp in unit tests

The saved model isn’t used. Remove to keep tests hermetic.

-    onnx.save(model, "/tmp/model_with_casted_output.onnx")

1076-1091: Strengthen assertions: verify Y1 connectivity and dtype

Only checking onnx.checker is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
-    onnx.checker.check_model(converted_model)
+    onnx.checker.check_model(converted_model)
+    # Y1 should remain connected and produced by a Cast
+    y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+    assert len(y1_producers) == 1
+    assert y1_producers[0].op_type == "Cast"
+    # keep_io_types=True -> FP32 I/O preserved
+    y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+    assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512dbb7 and 16d5875.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

607-617: Confirm intent: skipping only the duplicated “new_cast”, not the renamed original

casts_to_skip holds the original name (now assigned to new_cast) and will not skip the renamed original (..._io_special_case). If that original gets removed, _bypass_cast_node will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list as well right after renaming.

-                    casts_to_skip.append(node.name)
+                    casts_to_skip.append(node.name)
                     casts_to_add.append(new_cast)
                     # Now adjust the old cast's name, consumers and producers
                     node.name = f"{node.name}_io_special_case"
+                    casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)

600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift

If the model uses newer opsets where Cast may carry optional attributes (like saturate), the duplicate should copy them.

Apply:

-                    new_cast = helper.make_node(
-                        "Cast",
-                        name=node.name,
-                        inputs=[node.input[0]],
-                        outputs=[node.output[0]],
-                        to=utils.get_cast_to_type(node),
-                    )
+                    # Copy optional attributes (e.g., 'saturate' in newer opsets)
+                    saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+                    cast_attrs = {"to": utils.get_cast_to_type(node)}
+                    if saturate is not None:
+                        cast_attrs["saturate"] = saturate
+                    new_cast = helper.make_node(
+                        "Cast",
+                        name=node.name,
+                        inputs=[node.input[0]],
+                        outputs=[node.output[0]],
+                        **cast_attrs,
+                    )

618-621: Insert duplicate Cast adjacent to the original for better locality and readability

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep topology readable.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        # Preserve locality: insert duplicates next to their originals
+        for cast in casts_to_add:
+            target_idx = -1
+            for i, n in enumerate(self.model.graph.node):
+                if n.name == f"{cast.name}_io_special_case":
+                    target_idx = i
+                    break
+            if target_idx >= 0:
+                self.model.graph.node.insert(target_idx, cast)
+            else:
+                # Fallback to prepend to avoid end-append reordering
+                self.model.graph.node.insert(0, cast)

592-596: Use a set for casts_to_skip from the start

Minor nit for clarity and O(1) membership checks.

-        casts_to_skip = []
+        casts_to_skip: set[str] = set()
         # Add casts as a separate step to avoid modifying the graph while iterating over it
         casts_to_add = []
@@
-                    casts_to_skip.append(node.name)
+                    casts_to_skip.add(node.name)
@@
-        casts_to_skip = set(casts_to_skip)
+        # already a set

Also applies to: 620-621

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 16d5875 and 01308d6.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

588-617: Solid IO-cast preservation strategy

Duplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.

Inline comment on the diff context:

    value_info_map,
    initializer_map,
    node_to_init_map,
    keep_io_types=True,

Contributor
This test fails when keep_io_types=False due to a graph input becoming a graph output directly, which violates the assertion in ModelOpt that all original input and output names should be maintained in the quantized model.


@galagam galagam left a comment


This is indeed an edge case that challenges AutoCast's assumptions.
I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node.
@aboubezari @gcunhase

@gcunhase
Contributor

This is indeed an edge case that challenges AutoCast's assumptions. I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node. @aboubezari @gcunhase

@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!

@aboubezari
Author

This is indeed an edge case that challenges AutoCast's assumptions. I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node. @aboubezari @gcunhase

@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!

Will give this a try. Thanks.

@aboubezari
Author

@gcunhase @galagam I've implemented the suggestion. Can you take another look?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

1078-1097: Recompute mappings after sanitize() and assert outputs preserved

sanitize() mutates the graph (adds Identity and rewires). Recompute mappings to avoid stale references, and assert that both outputs remain present.

Apply this diff:

 def test_casted_input_to_output_model(
     model_with_casted_input_to_output, low_precision_type, keep_io_types
 ):
     model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output

     min_opset = 22 if low_precision_type == "bf16" else 13
     graph_sanitizer = GraphSanitizer(model, min_opset)
     graph_sanitizer.sanitize()
-    model = graph_sanitizer.model
+    model = graph_sanitizer.model
+    # Recompute mappings after graph mutation by sanitizer
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
     converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+    # Ensure both original outputs are preserved
+    assert {o.name for o in converted_model.graph.output} == {"Y1", "Y2"}

Optionally, to validate type expectations on the direct-IO-cast output:

  • If keep_io_types is True: Y1/Y2 should be FLOAT.
  • If False: Y1/Y2 should be low_precision_onnx_type(low_precision_type).

1028-1067: Optional: Expand fixture to cover multi-Cast-from-same-input edge case

To guard against future regressions (multiple Casts from the same input to different outputs), add a companion fixture/test. Example snippet to append to this file:

@pytest.fixture
def model_with_two_casted_outputs():
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])
    cast1 = helper.make_node("Cast", ["X"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["X"], ["Y2"], name="cast2", to=TensorProto.FLOAT)
    graph = helper.make_graph([cast1, cast2], "model_two_casted_outputs", [x], [y1, y2], [])
    model = helper.make_model(graph, producer_name="model_two_casted_outputs")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)
    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    return model, value_info_map, initializer_map, node_to_init_map

@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_two_casts_from_same_input(model_with_two_casted_outputs, keep_io_types, low_precision_type):
    model, value_info_map, initializer_map, node_to_init_map = model_with_two_casted_outputs
    gsani = GraphSanitizer(model, 22 if low_precision_type == "bf16" else 13)
    gsani.sanitize()
    model = gsani.model
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    conv = PrecisionConverter(
        model, value_info_map, initializer_map, node_to_init_map,
        keep_io_types=keep_io_types, low_precision_type=low_precision_type
    )
    converted = conv.convert(high_precision_nodes=["cast1", "cast2"], low_precision_nodes=[])
    onnx.checker.check_model(converted)
    assert {o.name for o in converted.graph.output} == {"Y1", "Y2"}

I can send a PR update with these additions if helpful.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9363b09 and 03529fc.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-450)
  • sanitize (53-68)
🔇 Additional comments (2)
modelopt/onnx/autocast/graphsanitizer.py (1)

68-69: Good placement in sanitize()

Running sanitize_io_casts() after cleanup and IR versioning keeps the injected Identity from being swept early. LGTM.

tests/unit/onnx/autocast/test_precisionconverter.py (1)

23-24: Import addition looks good

GraphSanitizer is used below; import is correct and scoped for tests.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)

1077-1098: Consider adding explicit assertions for the edge case.

The test validates that the model passes check_model after conversion, which confirms basic correctness. However, it could be strengthened by explicitly asserting that:

  1. The Identity node was inserted by the sanitizer
  2. The original input and output names are preserved
  3. The Cast node is still present

This would make the test's intent clearer and provide better regression coverage for the specific edge case being addressed.

Example assertions to add before line 1098:

# Verify Identity node was inserted for the IO cast
identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
assert len(identity_nodes) >= 1, "Expected at least one Identity node for IO cast isolation"

# Verify original I/O names are preserved
assert converted_model.graph.input[0].name == "X"
assert any(o.name == "Y1" for o in converted_model.graph.output)
assert any(o.name == "Y2" for o in converted_model.graph.output)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between caf9d39 and 0e0d11a.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (5 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (120-211)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-460)
  • sanitize (53-68)
🔇 Additional comments (6)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

33-33: LGTM! Constant definition is clear.

The constant LATEST_IR_VERSION_SUPPORTED_BY_ORT is well-named and appropriately scoped for test usage.


1031-1074: LGTM! Fixture correctly models the edge case.

The fixture constructs a graph where a Cast node connects input X directly to output Y1, alongside a separate computation path. This accurately captures the edge case described in the PR objectives.

modelopt/onnx/autocast/precisionconverter.py (4)

35-35: LGTM! Import is correctly placed.

The GraphSanitizer import is appropriately added to support the new sanitization step.


77-79: LGTM! New parameters are well-integrated.

The three new parameters (min_opset, max_ir_version, trt_plugins) are properly added to the constructor with sensible defaults and stored as instance variables for use in the sanitization step.

Also applies to: 116-118


142-142: LGTM! Sanitization is correctly positioned in the workflow.

The _sanitize_model() call is appropriately placed early in the convert() method, after model validation but before unsupported op filtering. This ensures the graph is normalized (e.g., IO casts are isolated with Identity nodes) before precision conversion logic executes.


1043-1051: LGTM! Sanitization method is clean and correct.

The _sanitize_model() method correctly instantiates GraphSanitizer with the model and relevant parameters, invokes sanitization, and updates the model. This cleanly integrates the graph normalization step (including IO cast isolation with Identity nodes) into the precision conversion workflow.

@aboubezari aboubezari force-pushed the fix_auto_cast_io_special_nodes_bug branch from 0e0d11a to 7dd644a on October 1, 2025 15:08
Update modelopt/onnx/autocast/precisionconverter.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <[email protected]>

cleanup

Signed-off-by: Ali Boubezari <[email protected]>

Inject identity nodes in sanitizer; revert existing logic; update test

Signed-off-by: Ali Boubezari <[email protected]>

Inject identity nodes in sanitizer; revert existing logic; update test

Signed-off-by: Ali Boubezari <[email protected]>

move pass

Signed-off-by: Ali Boubezari <[email protected]>

call sanitizer in precision converter

Signed-off-by: Ali Boubezari <[email protected]>
@aboubezari aboubezari force-pushed the fix_auto_cast_io_special_nodes_bug branch from 7dd644a to 4028e2a on October 1, 2025 15:10
@aboubezari
Author

@galagam @gcunhase I've improved the fix and confirmed it fixes some customer issues we've been having. Please take a look and let me know if there's anything else to change. Thanks!


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

77-92: Document the new parameters in the docstring.

The three new parameters (min_opset, max_ir_version, trt_plugins) are not documented in the docstring.

Apply this diff:

         Args:
             model: ONNX model to convert.
             value_info_map: Map of tensor names to value info.
             initializer_map: Map of tensor names to initializers.
             node_to_init_map: Map of node names to lists of initializer names.
             keep_io_types: Keep the input and output types of the model, otherwise they will be converted.
             low_precision_type: Precision to convert to.
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
                                        cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset version to use for sanitization (default: 13).
+            max_ir_version: Maximum IR version supported by ORT (default: None).
+            trt_plugins: List of TensorRT plugin library paths in .so format (default: []).
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

1063-1071: Consider matching parameter order to the GraphSanitizer constructor signature.

While using keyword arguments makes order irrelevant, matching the constructor's parameter order improves readability.

Apply this diff:

     def _sanitize_model(self):
         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
             max_ir_version=self.max_ir_version,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e0d11a and 4028e2a.

📒 Files selected for processing (3)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • modelopt/onnx/autocast/precisionconverter.py (5 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unit/onnx/autocast/test_precisionconverter.py
  • modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-560)
  • sanitize (53-69)
🔇 Additional comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

35-35: LGTM!

The import is correctly placed and necessary for the sanitization functionality.


142-143: LGTM!

The sanitization is correctly invoked early in the conversion workflow, after validation but before precision conversion operations.

Comment on lines +77 to +79
min_opset: int = 13,
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],


⚠️ Potential issue | 🟠 Major

Replace mutable default argument with None.

The default value trt_plugins=[] creates a mutable default argument, which can lead to shared state between instances if the list is mutated.

Apply this diff:

-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,

Then update line 1067 to handle the None case:

         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
             max_ir_version=self.max_ir_version,
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (constructor signature):

    def __init__(
        ...,
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = None,
    ):
        ...

Suggested change (GraphSanitizer call site):

    graph_sanitizer = GraphSanitizer(
        self.model,
        self.min_opset,
        trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
        max_ir_version=self.max_ir_version,
    )
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 77 to 79, change
the function signature to use trt_plugins: list[str] | None = None instead of a
mutable default list, and then at line 1067 update the code to treat a None
value as an empty list (e.g., set local_trt_plugins = trt_plugins or [] before
using it) so any subsequent iterations or mutations operate on a fresh list
rather than a shared default.
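The pitfall flagged above can be demonstrated in a few lines of plain Python (function and variable names are illustrative):

```python
def add_plugin_bad(name, plugins=[]):  # the [] is created ONCE, at def time
    plugins.append(name)
    return plugins

def add_plugin_good(name, plugins=None):  # None sentinel: fresh list per call
    if plugins is None:
        plugins = []
    plugins.append(name)
    return plugins

first = add_plugin_bad("a")
second = add_plugin_bad("b")   # mutates the SAME default list as the first call
# first is second -> True; both now hold ['a', 'b']

fresh1 = add_plugin_good("a")
fresh2 = add_plugin_good("b")  # independent lists: ['a'] and ['b']
```

This is why the review suggests `trt_plugins: list[str] | None = None` with an explicit fallback to `[]` at the call site.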
