1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -9,6 +9,7 @@ Model Optimizer Changelog (Linux)
**New Features**

- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
- Add flags ``nodes_to_include`` and ``op_types_to_include`` in AutoCast to force-include nodes in low precision, even if they would otherwise be excluded by other rules.

0.37 (2025-09-xx)
^^^^^^^^^^^^^^^^^
15 changes: 15 additions & 0 deletions docs/source/guides/8_autocast.rst
@@ -31,6 +31,8 @@ AutoCast can also be used programmatically through its Python API:
low_precision_type="fp16", # or "bf16"
nodes_to_exclude=None, # optional list of node name patterns to keep in FP32
op_types_to_exclude=None, # optional list of op types to keep in FP32
nodes_to_include=None, # optional list of node name patterns to force-include in low precision
op_types_to_include=None, # optional list of op types to force-include in low precision
data_max=512, # threshold for node outputs
init_max=65504, # threshold for initializers
keep_io_types=False, # whether to preserve input/output types
@@ -60,6 +62,19 @@ AutoCast follows these steps to convert a model:
- Analyzes each node in the graph
- Determines which nodes should remain in FP32 based on input and output tensor magnitudes, operation types, and node name patterns
- If a calibration dataset is provided, it is used to generate intermediate tensor magnitudes for more accurate node classification; otherwise, random data is used.
- Use ``nodes_to_include`` and ``op_types_to_include`` to force-include nodes in low precision, even if they would otherwise be excluded (a short usage sketch follows the classification rules below).

- Default classification rules. Nodes matching any of these rules are kept in high precision:
- Node I/O magnitudes are higher than ``data_max`` (default: 512). Due to precision limitations, computation on high-magnitude tensors in low precision may be inaccurate: in FP16, the unit in the last place (ULP) at 512 is 0.5, at 1024 it is 1.0, and so on (see the NumPy illustration below).
- Initializer magnitudes are higher than ``init_max`` (default: 65504, the FP16 maximum). Initializers are often used for non-compute-intensive operations and are more likely to be controlled by the user; however, values above ``init_max`` would overflow in FP16, so they are kept in high precision.
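
The ULP claim is easy to check with NumPy (a standalone illustration, not part of the AutoCast API):

.. code-block:: python

    import numpy as np

    # Spacing between adjacent representable FP16 values grows with magnitude
    np.spacing(np.float16(512))   # 0.5
    np.spacing(np.float16(1024))  # 1.0
    np.finfo(np.float16).max      # 65504.0, the default init_max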

Additional classification rules (disabled by default):
- ``max_depth_of_reduction``: Keep nodes with a high depth of reduction (e.g., large matrix multiplications, convolutions with large kernels) in high precision.
- ``nodes_to_exclude``: List of regex patterns for node names to keep in high precision.
- ``op_types_to_exclude``: List of operation types to keep in high precision.
- ``nodes_to_include``: List of regex patterns for node names to force-include in low precision.
- ``op_types_to_include``: List of operation types to force-include in low precision.
- ``custom_rule``: Optional custom rule for node classification (must inherit from ``NodeRuleBase``).
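
A minimal usage sketch combining exclude and include rules (a sketch under assumptions, not the full API: the positional model argument is assumed to be the ONNX file path, and the regex patterns are illustrative; ``convert_to_mixed_precision`` is defined in ``modelopt.onnx.autocast.convert``):

.. code-block:: python

    from modelopt.onnx.autocast.convert import convert_to_mixed_precision

    model = convert_to_mixed_precision(
        "model.onnx",                       # assumed: path to an FP32 ONNX model
        low_precision_type="fp16",
        op_types_to_exclude=["Softmax"],    # keep a numerically sensitive op type in FP32
        nodes_to_include=["attention/.*"],  # ...but force matching nodes into FP16 anyway
    )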

#. **Precision Conversion**:

23 changes: 22 additions & 1 deletion modelopt/onnx/autocast/__main__.py
@@ -85,6 +85,24 @@ def get_parser() -> argparse.ArgumentParser:
default=[],
help="List of op types that should remain in FP32",
)
parser.add_argument(
"--nodes_to_include",
"-ni",
type=str,
nargs="*",
default=[],
help="List of regex patterns to match node names that should be force-included in low precision, even if they "
"would otherwise be excluded",
)
parser.add_argument(
"--op_types_to_include",
"-opi",
type=str,
nargs="*",
default=[],
help="List of op types that should be force-included in low precision, even if they would otherwise be "
"excluded",
)
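# A hypothetical command line exercising the new flags (the module entry point
# follows from this package's __main__.py; the positional model-path argument is
# an assumption, as it is not shown in this excerpt):
#
#   python -m modelopt.onnx.autocast model.onnx \
#       --op_types_to_exclude Softmax \
#       --nodes_to_include ".*attention.*" \
#       --op_types_to_include MatMul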
parser.add_argument(
"--data_max",
type=float,
@@ -112,7 +130,8 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--keep_io_types",
action="store_true",
help="Keep the input and output types of the model, otherwise they will be converted to FP16",
help="Keep the input and output types of the model; otherwise they will be converted to reduced precision "
"(FP16/BF16)",
)
parser.add_argument(
"--log_level",
@@ -164,6 +183,8 @@ def main(argv=None):
low_precision_type=args.low_precision_type,
nodes_to_exclude=args.nodes_to_exclude,
op_types_to_exclude=args.op_types_to_exclude,
nodes_to_include=args.nodes_to_include,
op_types_to_include=args.op_types_to_include,
data_max=args.data_max,
init_max=args.init_max,
keep_io_types=args.keep_io_types,
8 changes: 7 additions & 1 deletion modelopt/onnx/autocast/convert.py
@@ -17,7 +17,7 @@
"""AutoCast module for converting ONNX models to mixed precision.

AutoCast is a tool for converting FP32 ONNX models to mixed precision FP32-FP16 or FP32-BF16 models.
While casting FP32 to FP6/BF16, some nodes might be more sensitive to effecting accuracy.
While casting FP32 to FP16/BF16, some nodes are more sensitive to the reduced precision and can degrade accuracy.
AutoCast intelligently selects nodes to keep in FP32 precision to maintain model accuracy while benefiting from
reduced precision on the rest of the nodes. AutoCast automatically injects cast operations around the selected
nodes.
@@ -48,6 +48,8 @@ def convert_to_mixed_precision(
low_precision_type: str = "fp16",
nodes_to_exclude: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
nodes_to_include: list[str] | None = None,
op_types_to_include: list[str] | None = None,
data_max: float = DEFAULT_DATA_MAX,
init_max: float = DEFAULT_INIT_MAX,
keep_io_types: bool = False,
@@ -65,6 +67,8 @@ def convert_to_mixed_precision(
low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
nodes_to_exclude: List of regex patterns to match node names that should remain in FP32.
op_types_to_exclude: List of operation types that should remain in FP32.
nodes_to_include: List of regex patterns to match node names that should be force-included in low precision, even if excluded by other rules.
op_types_to_include: List of operation types that should be force-included in low precision, even if excluded by other rules.
data_max: Maximum absolute value for node input and output values.
init_max: Maximum absolute value for initializers.
keep_io_types: Whether to preserve input/output types.
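
For reference, a call sketch exercising these parameters (the positional model argument is assumed to be an ONNX file path; the regex is illustrative):

    from modelopt.onnx.autocast.convert import convert_to_mixed_precision

    # Hypothetical example: convert to BF16 while keeping FP32 graph inputs/outputs
    converted = convert_to_mixed_precision(
        "model.onnx",
        low_precision_type="bf16",
        keep_io_types=True,
        nodes_to_exclude=["head/.*"],  # illustrative pattern for a model's head
    )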
@@ -108,6 +112,8 @@ def convert_to_mixed_precision(
initializer_map,
nodes_to_exclude=nodes_to_exclude or [],
op_types_to_exclude=op_types_to_exclude or [],
nodes_to_include=nodes_to_include or [],
op_types_to_include=op_types_to_include or [],
data_max=data_max,
init_max=init_max,
custom_rule=custom_rule,
51 changes: 48 additions & 3 deletions modelopt/onnx/autocast/nodeclassifier.py
@@ -93,6 +93,28 @@ def _check_inner(self, node):
return node.op_type in self.op_types_to_exclude


class IncludeNodeNameRegexRule(DisabledNodeNameRegexRule):
"""Rule for force-including nodes with matching names in low precision.

Inherits matching behavior from DisabledNodeNameRegexRule but overrides logging.
"""

def _log_skipped(self, node, **kwargs):
# For include rules, a positive match means we will force-include the node in low precision
logger.info(f"Force-including node {node.name}: {self.__class__.__name__}")


class IncludeOpTypes(DisabledOpTypes):
"""Rule for force-including specific operation types in low precision.

Inherits matching behavior from DisabledOpTypes but overrides logging.
"""

def _log_skipped(self, node, **kwargs):
# For include rules, a positive match means we will force-include the node in low precision
logger.info(f"Force-including node {node.name}: {self.__class__.__name__}")
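
# A minimal custom-rule sketch (hypothetical, for illustration): a custom_rule
# passed to NodeClassifier follows the same NodeRuleBase contract as the rules
# above, where _check_inner() returns True for nodes that must stay in FP32.
# This assumes NodeRuleBase can be constructed without extra arguments.
class KeepMarkedGemmRule(NodeRuleBase):
    """Hypothetical rule: keep Gemm nodes whose name carries a marker in FP32."""

    def _check_inner(self, node):
        return node.op_type == "Gemm" and "keep_fp32" in node.name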


class InitializerRangeRule(NodeRuleBase):
"""Rule for keeping nodes with out-of-range initializers in high precision."""

@@ -332,6 +354,8 @@ def __init__(
initializer_map: dict[str, onnx.TensorProto] | None = None,
nodes_to_exclude: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
nodes_to_include: list[str] | None = None,
op_types_to_include: list[str] | None = None,
custom_rule: NodeRuleBase | None = None,
data_max: float | None = 1000.0,
init_max: float | None = np.finfo(np.float16).max,
@@ -345,6 +369,8 @@ def __init__(
initializer_map: Mapping from initializer names to their tensors.
nodes_to_exclude: List of regex patterns for node names to keep in high precision.
op_types_to_exclude: List of operation types to keep in high precision.
nodes_to_include: List of regex patterns for node names to force-include in low precision.
op_types_to_include: List of operation types to force-include in low precision.
custom_rule: Optional custom classification rule.
data_max: Maximum absolute value allowed for node I/O.
init_max: Maximum absolute value allowed for initializers.
@@ -355,12 +381,14 @@ def __init__(
self.initializer_map = initializer_map
self.nodes_to_exclude = nodes_to_exclude
self.op_types_to_exclude = op_types_to_exclude
self.nodes_to_include = nodes_to_include
self.op_types_to_include = op_types_to_include
self.custom_rule = custom_rule
self.data_max = data_max
self.init_max = init_max
self.max_depth_of_reduction = max_depth_of_reduction

def _gen_block_node_rules(self, reference_data):
def _gen_exclude_node_rules(self, reference_data):
"""Generate list of rules for excluding nodes from precision conversion.

Args:
@@ -393,6 +421,20 @@ def _gen_block_node_rules(self, reference_data):
block_node_rules.append(self.custom_rule)
return block_node_rules

def _gen_include_node_rules(self):
"""Generate list of rules for force-including nodes in low precision.

Returns:
list[NodeRuleBase]: List of rules to apply.
"""
include_node_rules: list[NodeRuleBase] = []
if self.nodes_to_include:
include_node_rules.append(IncludeNodeNameRegexRule(self.nodes_to_include))
if self.op_types_to_include:
include_node_rules.append(IncludeOpTypes(self.op_types_to_include))

return include_node_rules

def run(self, ref_outputs_dict=None):
"""Run node classification.

@@ -402,12 +444,15 @@ def run(self, ref_outputs_dict=None):
Returns:
tuple: Lists of node names (low_precision_nodes, high_precision_nodes).
"""
block_node_rules = self._gen_block_node_rules(ref_outputs_dict)
exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
include_node_rules = self._gen_include_node_rules()
low_precision_nodes = []
high_precision_nodes = []
for node in self.model.graph.node:
# A node is kept in high precision only if an exclude rule matches and no include rule overrides it
if any(rule.check(node) for rule in block_node_rules):
if any(rule.check(node) for rule in exclude_node_rules) and not any(
rule.check(node) for rule in include_node_rules
):
high_precision_nodes.append(node.name)
else:
low_precision_nodes.append(node.name)
52 changes: 52 additions & 0 deletions tests/unit/onnx/autocast/test_nodeclassifier.py
@@ -330,3 +330,55 @@ def test_node_classifier_op_types_to_exclude(test_model):
assert len(fp16_nodes) + len(fp32_nodes) == 2
# Test that no node is in both fp16 and fp32 lists
assert not set(fp16_nodes).intersection(set(fp32_nodes))


# Test that nodes_to_include and op_types_to_include force nodes into low precision,
# even if they would otherwise be excluded by other rules.
def test_node_classifier_force_include(test_model):
node_to_init_map = {
"add_node": [
numpy_helper.from_array(np.array([[10.0, 20.0], [30.0, 40.0]], dtype=np.float32))
],
"mul_node": [numpy_helper.from_array(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))],
}

# Set init_max low so both nodes would normally be excluded (kept in FP32)
# Force add_node to low precision, despite exceeding init_max
classifier = NodeClassifier(
model=test_model,
node_to_init_map=node_to_init_map,
init_max=1.0,
nodes_to_include=["add_node"],
)
fp16_nodes, fp32_nodes = classifier.run()
# add_node should be in fp16_nodes due to nodes_to_include, despite exceeding init_max
assert "add_node" in fp16_nodes
assert "mul_node" in fp32_nodes
assert "add_node" not in fp32_nodes
assert len(fp16_nodes) + len(fp32_nodes) == 2
assert not set(fp16_nodes).intersection(set(fp32_nodes))

# Test that the node include rule overrides the op-type exclude rule
classifier2 = NodeClassifier(
model=test_model,
node_to_init_map=node_to_init_map,
op_types_to_exclude=["Add"],
nodes_to_include=["add_node"], # Should override op_types_to_exclude
)
fp16_nodes, fp32_nodes = classifier2.run()
assert "add_node" in fp16_nodes
assert "add_node" not in fp32_nodes
assert not set(fp16_nodes).intersection(set(fp32_nodes))

# Set init_max low so both nodes would normally be excluded (kept in FP32)
# Force op type Mul to low precision, despite exceeding init_max
classifier3 = NodeClassifier(
model=test_model,
node_to_init_map=node_to_init_map,
init_max=1.0,
op_types_to_include=["Mul"],
)
fp16_nodes, fp32_nodes = classifier3.run()
assert "mul_node" in fp16_nodes
assert "add_node" in fp32_nodes
assert not set(fp16_nodes).intersection(set(fp32_nodes))