diff --git a/pyproject.toml b/pyproject.toml
index 204f1d6..f98771e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
     "License :: OSI Approved :: Apache Software License",
 ]
-dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
+dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "sympy"]
 
 [project.urls]
 Homepage = "https://onnx.ai/ir-py"
diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py
index af5a258..84281bf 100644
--- a/src/onnx_ir/_core.py
+++ b/src/onnx_ir/_core.py
@@ -45,6 +45,8 @@
 import ml_dtypes
 import numpy as np
+import sympy
+import sympy.utilities.misc
 from typing_extensions import TypeIs
 
 import onnx_ir
@@ -1115,13 +1117,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
     It is immutable and can be compared or hashed.
     """
 
-    __slots__ = ("_value",)
+    __slots__ = ("_expr", "_value")
 
-    def __init__(self, value: str | None) -> None:
+    def __init__(self, value: str | None, /, expr: sympy.Expr | None = None) -> None:
         """Initialize a symbolic dimension.
 
         Args:
             value: The value of the dimension. It should not be an int.
+            expr: An optional sympy expression representing the dimension.
 
         Raises:
             TypeError: If value is an int.
@@ -1132,6 +1135,7 @@ def __init__(self, value: str | None) -> None:
                 "If you are creating a Shape, use int directly instead of SymbolicDim."
             )
         self._value = value
+        self._expr: sympy.Expr | None = expr
 
     def __eq__(self, other: object) -> bool:
         """Check equality with another SymbolicDim or string/None."""
@@ -1148,11 +1152,24 @@ def value(self) -> str | None:
         """The value of the symbolic dimension (string or None)."""
         return self._value
 
+    @property
+    def expr(self) -> sympy.Expr | None:
+        """The sympy expression representing the symbolic dimension."""
+        return self._expr
+
     def __str__(self) -> str:
-        return f"{self._value}"
+        if self._value is not None:
+            return str(self._value)
+        if self._expr is not None:
+            return str(self._expr)
+        return "?"
 
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({self._value})"
+        if self._expr is not None:
+            expr_text = f", expr={self._expr!r}"
+        else:
+            expr_text = ""
+        return f"{self.__class__.__name__}({self._value}{expr_text})"
 
 
 def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
@@ -1190,10 +1207,16 @@ def _maybe_convert_to_symbolic_dim(
         return SymbolicDim(dim)
     if _is_int_compatible(dim):
         return int(dim)
+    if isinstance(dim, sympy.Expr):
+        # A constant integer expression becomes a plain int dimension;
+        # anything else is wrapped in a SymbolicDim carrying the expression
+        if dim.is_number and dim.is_integer:
+            return sympy.utilities.misc.as_int(dim)
+        return SymbolicDim(str(dim), expr=dim)
     if isinstance(dim, SymbolicDim):
         return dim
     raise TypeError(
-        f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
+        f"Expected int, str, sympy.Expr, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
     )
 
 
@@ -1334,7 +1357,9 @@ def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ...
     def __getitem__(self, index):
         return tuple(self._dims)[index]
 
-    def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
+    def __setitem__(
+        self, index: int, value: int | SymbolicDim | str | sympy.Expr | None
+    ) -> None:
         """Set the dimension at the index.
         Args:
diff --git a/src/onnx_ir/_shape_type_inference/README.md b/src/onnx_ir/_shape_type_inference/README.md
new file mode 100644
index 0000000..29094c9
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/README.md
@@ -0,0 +1,262 @@
+# Symbolic Shape and Type Inference
+
+This module provides symbolic shape and type inference for ONNX IR models, enabling compile-time analysis of tensor shapes and types with support for symbolic dimensions.
+
+## Overview
+
+The inference engine performs forward propagation through ONNX models to determine output shapes and types based on input specifications. It supports symbolic dimensions using SymPy expressions, allowing for dynamic shape analysis.
+
+## Key Components
+
+### Core Classes
+
+- **`SymbolicInferenceEngine`**: Main orchestrator that processes models and applies inference
+- **`NodeInferrer`**: Base class for operation-specific inference logic
+- **`InferenceResult`**: Container for inference results with status and optional message
+- **`InferenceStatus`**: Enum for inference operation status (SUCCESS, PARTIAL, MISSING_INFO, INVALID_NODE)
+
+### Reconciliation Policies
+
+The engine supports different strategies for handling conflicts between inferred and existing values:
+
+- **`OVERWRITE`**: Always use inferred values
+- **`IGNORE`**: Keep existing values if they exist
+- **`RECONCILE`**: Merge inferred and existing values intelligently
+- **`STRICT`**: Fail if inferred values don't match existing ones
+
+### Inference Status System
+
+The `InferenceResult` uses a status-based approach for granular error handling:
+
+- **`SUCCESS`**: Complete inference successful with full shape/type information
+- **`PARTIAL`**: Partial information available (e.g., type only, rank only)
+- **`MISSING_INFO`**: Missing required input information (shapes, types)
+- **`INVALID_NODE`**: Node is invalid or malformed
+
+```python
+# Example usage in inferrers
+def infer(self, node: ir.Node) -> InferenceResult:
+    if node.inputs[0].shape is None:
+        return InferenceResult(
+            status="missing_info",
+            msg="Input shape is required"
+        )
+
+    # Partial inference - only type available
+    if can_infer_type_only():
+        return InferenceResult(
+            values=[ir.Value(type=inferred_type)],
+            status="partial",
+            msg="Shape unavailable, type only"
+        )
+
+    # Full inference (status defaults to "success")
+    return InferenceResult(values=[full_value])
+```
+
+## Architecture
+
+```text
+SymbolicInferenceEngine
+├── NodeInferrer Registry (keyed by domain, op_type, overload)
+├── Opset Version Matching
+└── Reconciliation Logic
+
+NodeInferrer Implementations
+├── ElementwiseInferrer (unary operations)
+├── BinaryInferrer (broadcasting operations)
+└── Specialized Inferrers (50+ operations)
+```
+
+## Inferrer Selection
+
+The engine selects the appropriate inferrer using a two-stage process:
+
+1. **Registry Lookup**: Inferrers are registered by `(domain, op_type, overload)` key
+2. **Opset Matching**: Among matching inferrers, select those supporting the model's opset version
+
+```python
+# Example: For a Squeeze node with opset 14
+# - Multiple Squeeze inferrers may be registered
+# - Engine selects Squeeze13Inferrer (supports opset 13-23)
+# - Ignores Squeeze12Inferrer (supports opset 1-12)
+```
+
+## Symbolic Dimensions
+
+The system stores symbolic expressions in `ir.SymbolicDim` objects:
+
+```python
+class SymbolicDim:
+    value: str | None  # String identifier (e.g., "N", "batch_size")
+    expr: sympy.Expr | None  # SymPy expression for computed dimensions
+```
+
+Dimensions are accessed via `get_expr()`, which converts them to SymPy expressions:
+
+- `SymbolicDim("N")` → `sympy.Symbol("N")`
+- `SymbolicDim("N*2", expr=N*2)` → `N*2` (SymPy expression)
+- Integer dimensions → `sympy.Integer(value)`
+
+## NodeInferrer Design Decisions
+
+### Base Class Structure
+
+The `NodeInferrer` abstract base class enforces a consistent interface:
+
+```python
+class NodeInferrer(abc.ABC):
+    def __init__(self, op_type: str, opsets: Collection[int], domain: str = "", overload: str = ""):
+        # Store operation metadata for registry matching
+        ...
+
+    @abc.abstractmethod
+    def infer(self, node: ir.Node) -> InferenceResult:
+        # Operation-specific inference logic
+        ...
+```
+
+### Design Rationale
+
+1. **Single Responsibility**: Each inferrer handles exactly one operation type
+2. **Opset Awareness**: Inferrers declare supported ONNX opset versions for compatibility
+3. **Domain Support**: Enables custom domains beyond standard ONNX operators
+4. **Validation Decorators**: `@requires_non_none_inputs(n)` and `@requires_outputs(n)` provide consistent input validation
+5. **Failure Handling**: Return `InferenceResult` with `values` on success, or a non-success `status` plus `msg`, for graceful error handling
+
+### Inheritance Patterns
+
+- **ElementwiseInferrer**: Template for unary operations that preserve input shape/type
+- **BinaryInferrer**: Template for binary operations with broadcasting logic
+- **Specialized Inferrers**: Custom logic for complex operations (Conv, Reshape, etc.)
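+
+A minimal sketch of a custom inferrer built from these pieces (the `Size`
+operator is used purely for illustration and is not part of the shipped set;
+names are assumed to be imported from `_common`):
+
+```python
+class SizeInferrer(NodeInferrer):
+    """Size produces a scalar INT64 tensor regardless of the input shape."""
+
+    def __init__(self) -> None:
+        super().__init__("Size", opsets=range(1, 24))
+
+    @requires_non_none_inputs(1)
+    @requires_outputs(1)
+    def infer(self, node: ir.Node) -> InferenceResult:
+        scalar = ir.Value(shape=ir.Shape([]), type=ir.TensorType(ir.DataType.INT64))
+        return InferenceResult(values=[scalar])
+```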
+
+## Usage
+
+### Basic Usage
+
+```python
+from onnx_ir._shape_type_inference.factory import create_standard_inference_engine
+from onnx_ir._shape_type_inference import ReconciliationPolicy
+
+# Create engine with all standard operations
+engine = create_standard_inference_engine(ReconciliationPolicy.RECONCILE)
+
+# Perform inference on a model
+engine.infer_model(model)
+```
+
+### Custom Engine
+
+```python
+from onnx_ir._shape_type_inference import SymbolicInferenceEngine
+from onnx_ir._shape_type_inference.ops.matmul import MatMulInferrer
+from onnx_ir._shape_type_inference.ops.standard_ops import BinaryInferrer
+
+# Create custom engine with specific operations
+inferrers = [
+    MatMulInferrer(),
+    BinaryInferrer("Add"),
+    BinaryInferrer("Mul"),
+]
+
+engine = SymbolicInferenceEngine(inferrers, ReconciliationPolicy.STRICT)
+```
+
+## Opset Version Support
+
+Each inferrer specifies supported ONNX opset versions to handle API changes:
+
+```python
+class Squeeze12Inferrer(NodeInferrer):
+    def __init__(self):
+        super().__init__("Squeeze", opsets=range(1, 13))
+
+class Squeeze13Inferrer(NodeInferrer):
+    def __init__(self):
+        super().__init__("Squeeze", opsets=range(13, 24))
+```
+
+## Error Handling
+
+The engine provides comprehensive error handling:
+
+- **Validation Errors**: Invalid input/output counts, missing shapes
+- **Type Mismatches**: Incompatible input types for binary operations
+- **Inference Failures**: Operation-specific inference errors
+- **Reconciliation Conflicts**: Value mismatches in strict mode
+
+## Factory Functions
+
+Pre-configured engines for common use cases:
+
+- **`create_standard_inference_engine()`**: Full operation coverage (50+ ops)
+- **`create_minimal_inference_engine()`**: Essential operations only
+
+## Subgraphs and ONNX Functions
+
+### Design Approach
+
+#### Subgraph Pre-Processing Strategy
+
+The engine uses a **subgraph-first** approach for cleaner separation of concerns:
+
+1. **Pre-Processing Phase**: Before running node inference, detect and recursively process all subgraphs
+2. **Bottom-Up Inference**: Subgraphs are fully inferred before their parent nodes
+3. **Simplified Node Logic**: Control flow inferrers (If, Loop, Scan) can assume subgraph shapes are already available
+
+```python
+class SymbolicInferenceEngine:
+    def _infer_node(self, node: ir.Node, model: ir.Model) -> None:
+        # First: recursively infer any subgraphs
+        for attr in node.attributes.values():
+            if isinstance(attr.value, ir.Graph):
+                self._infer_subgraph(attr.value, model)
+
+        # Then: run node-specific inference with subgraphs already processed
+        inferrer = self._find_inferrer(node, model)
+        result = inferrer.infer(node)  # Subgraph shapes already available
+```
+
+#### ONNX Function Support
+
+Functions are handled through **automatic expansion** without custom inferrer logic:
+
+1. **Function Context**: Engine maintains intermediate value mappings during function execution
+2. **Transparent Expansion**: Function calls are expanded inline and processed like regular subgraphs
+3. **No Custom Logic**: Users don't implement function-specific inferrers - the engine handles it automatically
+
+```python
+class SymbolicInferenceEngine:
+    def _infer_function_call(self, node: ir.Node, function: ir.Function) -> InferenceResult:
+        # Create isolated context for function execution
+        function_context = self._create_function_context(node.inputs, function)
+
+        # Process function body as a subgraph
+        for func_node in function.nodes:
+            self._infer_node_in_context(func_node, function_context)
+
+        # Map function outputs back to caller node
+        return self._extract_function_outputs(function_context, function.outputs)
+```
+
+### Key Benefits
+
+1. **Cleaner Separation**: Subgraph inference is handled by the engine, not individual inferrers
+2. **Automatic Function Support**: No need to implement custom logic for each function
+3. **Simplified Debugging**: Each phase (subgraphs → nodes) can be debugged independently
+4. **Consistent Context**: Function calls maintain proper variable scoping and type consistency
+
+## Extension Points
+
+To add support for new operations:
+
+1. Create a new inferrer class inheriting from `NodeInferrer`
+2. Implement the `infer()` method with operation-specific logic
+3. Register with the engine or add to factory functions
+
+```python
+class CustomOpInferrer(NodeInferrer):
+    def __init__(self):
+        super().__init__("CustomOp", opsets=range(1, 24), domain="custom_domain")
+
+    def infer(self, node: ir.Node) -> InferenceResult:
+        # Custom inference logic
+        return InferenceResult(values=[result_value])
+```
diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py
new file mode 100644
index 0000000..a9a892e
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/__init__.py
@@ -0,0 +1,20 @@
+"""Symbolic shape and type inference for ONNX IR."""
+
+__all__ = [
+    "SymbolicInferenceEngine",
+    "InferenceError",
+    "NodeInferrer",
+    "InferenceResult",
+    "InferenceStatus",
+    "ReconciliationPolicy",
+]
+
+
+from onnx_ir._shape_type_inference._common import (
+    InferenceResult,
+    InferenceStatus,
+    NodeInferrer,
+)
+from onnx_ir._shape_type_inference._engine import (
+    InferenceError,
+    ReconciliationPolicy,
+    SymbolicInferenceEngine,
+)
diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py
new file mode 100644
index 0000000..d2da467
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/_common.py
@@ -0,0 +1,195 @@
+"""Symbolic shape inference for ONNX IR."""
+
+from __future__ import annotations
+
+import abc
+import enum
+import functools
+from collections.abc import Collection, Sequence
+from typing import Any, Callable
+
+import sympy
+
+import onnx_ir as ir
+
+MAX_SUPPORTED_OPSET = 23
+
+
+def get_expr(shape: ir.Shape, index: int) -> sympy.Expr:
+    """Get the expression or value at a specific index in the shape.
+
+    Args:
+        shape: The shape to get the expression from.
+        index: The index of the dimension to get.
+
+    Returns:
+        The expression or value at the specified index.
+    """
+    dim = shape[index]
+    if isinstance(dim, ir.SymbolicDim):
+        if dim.expr is not None:
+            return dim.expr
+        if dim.value is None:
+            # Use a fresh Dummy symbol so that two unknown dimensions do not
+            # spuriously compare equal to each other.
+            return sympy.Dummy("unknown")
+        return sympy.Symbol(dim.value)
+    return sympy.Integer(dim)
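+
+
+# Example (a sketch): given shape = ir.Shape(("N", 3)), get_expr(shape, 0)
+# returns sympy.Symbol("N") and get_expr(shape, 1) returns sympy.Integer(3).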
+
+
+@enum.unique
+class InferenceStatus(enum.Enum):
+    """Status of shape inference operation."""
+
+    SUCCESS = "success"  # Complete inference successful
+    PARTIAL = "partial"  # Partial information available (e.g., type only, rank only)
+    MISSING_INFO = "missing_info"  # Missing required input information
+    INVALID_NODE = "invalid_node"  # Node is invalid or malformed
+
+
+class InferenceResult:
+    """Container for inference results with status and optional message."""
+
+    def __init__(
+        self,
+        values: Sequence[ir.Value] | None = None,
+        status: str | InferenceStatus = "success",
+        msg: str | None = None,
+    ) -> None:
+        """Initialize inference result.
+
+        Args:
+            values: Sequence of inferred values.
+            status: Status of inference operation (string or enum).
+            msg: Optional message for context.
+        """
+        self.values = values
+        self.status = InferenceStatus(status)
+        self.msg = msg
+
+    def __repr__(self) -> str:
+        """Return string representation of the result."""
+        return f"InferenceResult(values={self.values}, status={self.status.value}, msg={self.msg!r})"
+
+
+class NodeInferrer(abc.ABC):
+    """Base class for node inferrers.
+
+    This class provides a common interface for all node inferrers.
+    """
+
+    def __init__(
+        self, op_type: str, opsets: Collection[int], domain: str = "", overload: str = ""
+    ) -> None:
+        """Initialize the node inferrer.
+
+        Args:
+            op_type: The type of the operation.
+            opsets: A collection of ONNX opset versions supported by this inferrer.
+            domain: The domain of the operation, default is an empty string.
+            overload: The overload identifier for the operation, default is an empty string.
+        """
+        self.op_type = op_type
+        self.opsets = opsets
+        self.domain = domain
+        self.overload = overload
+
+    def __repr__(self) -> str:
+        """Return a string representation of the node inferrer."""
+        return f"{self.__class__.__name__}(op_type={self.op_type}, opsets={self.opsets}, domain={self.domain})"
+
+    @abc.abstractmethod
+    def infer(self, node: ir.Node) -> InferenceResult:
+        """Infer the shape and type for the node.
+
+        Args:
+            node: The ONNX node to infer the type and shape for.
+
+        Returns:
+            An InferenceResult carrying the inferred values, a status, and an
+            optional message.
+        """
+        raise NotImplementedError
+
+
+def requires_non_none_inputs(
+    count: int, /
+) -> Callable[
+    [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]
+]:
+    """Ensure that the node has a specific number of non-None inputs.
+
+    Args:
+        count: The exact number of non-None inputs required for the node.
+
+    Returns:
+        A decorator that checks the number of inputs and their non-None status.
+    """
+
+    def decorator(
+        func: Callable[[Any, ir.Node], InferenceResult],
+    ) -> Callable[[Any, ir.Node], InferenceResult]:
+        @functools.wraps(func)
+        def wrapper(self, node: ir.Node) -> InferenceResult:
+            if len(node.inputs) != count:
+                return InferenceResult(
+                    status="invalid_node",
+                    msg=f"{node.op_type} must have {count} inputs, got {len(node.inputs)}.",
+                )
+            for i, inp in enumerate(node.inputs):
+                if inp is None:
+                    return InferenceResult(
+                        status="missing_info", msg=f"{node.op_type} input {i} cannot be None."
+                    )
+            return func(self, node)
+
+        return wrapper
+
+    return decorator
+
+
+def requires_outputs(
+    count: int, /
+) -> Callable[
+    [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]
+]:
+    """Ensure that the node has a specific number of outputs.
+
+    Args:
+        count: The exact number of outputs required for the node.
+
+    Returns:
+        A decorator that checks the number of outputs.
+    """
+
+    def decorator(
+        func: Callable[[Any, ir.Node], InferenceResult],
+    ) -> Callable[[Any, ir.Node], InferenceResult]:
+        @functools.wraps(func)
+        def wrapper(self, node: ir.Node) -> InferenceResult:
+            if len(node.outputs) != count:
+                return InferenceResult(
+                    status="invalid_node",
+                    msg=f"{node.op_type} must have {count} outputs, got {len(node.outputs)}.",
+                )
+            return func(self, node)
+
+        return wrapper
+
+    return decorator
+
+
+def inclusive_range(start_or_end: int = 0, end: int | None = None) -> range:
+    """Create an inclusive range from start to end.
+
+    Args:
+        start_or_end: The end of the range (inclusive) if ``end`` is None;
+            otherwise the starting value of the range.
+        end: The ending value of the range (inclusive).
+
+    Returns:
+        A range object that includes both start and end.
+    """
+    if end is None:
+        end = start_or_end
+        start = 0
+    else:
+        start = start_or_end
+
+    return range(start, end + 1)
diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py
new file mode 100644
index 0000000..bdf8e2b
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/_engine.py
@@ -0,0 +1,325 @@
+"""Symbolic inference engine for ONNX IR models."""
+
+from __future__ import annotations
+
+import enum
+import logging
+from collections.abc import Iterable, Sequence
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+logger = logging.getLogger(__name__)
+
+
+class ReconciliationPolicy(enum.Enum):
+    """Policy for reconciling inferred shapes/types with existing values."""
+
+    OVERWRITE = "overwrite"  # Always use inferred values
+    IGNORE = "ignore"  # Keep existing values if they exist
+    RECONCILE = "reconcile"  # Try to merge/validate inferred vs existing
+    STRICT = "strict"  # Fail if inferred doesn't match existing
+
+
+class InferenceError(RuntimeError):
+    """Error during shape inference."""
+
+
+class SymbolicInferenceEngine:
+    """Engine for performing symbolic shape and type inference on ONNX IR models."""
+
+    def __init__(
+        self,
+        node_inferrers: Iterable[_common.NodeInferrer],
+        reconciliation_policy: str | ReconciliationPolicy = "reconcile",
+    ) -> None:
+        """Initialize the symbolic inference engine.
+
+        Args:
+            node_inferrers: List of node inferrers to use for shape inference.
+            reconciliation_policy: Policy for handling conflicts between inferred and existing values.
+        """
+        self.reconciliation_policy = ReconciliationPolicy(reconciliation_policy)
+        self._inferrer_registry: dict[ir.OperatorIdentifier, list[_common.NodeInferrer]] = {}
+
+        # Register inferrers by (domain, op_type, overload)
+        for inferrer in node_inferrers:
+            key = (inferrer.domain, inferrer.op_type, inferrer.overload)
+            self._inferrer_registry.setdefault(key, []).append(inferrer)
+
+    def infer_model(self, model: ir.Model) -> None:
+        """Perform shape and type inference on an entire model.
+
+        Args:
+            model: The ONNX IR model to perform inference on.
+
+        Raises:
+            InferenceError: If inference fails for any node.
+ """ + logger.info("Starting inference on model with %s nodes", len(model.graph)) + + # Process nodes in topological order + for i, node in enumerate(model.graph): + try: + self._infer_node(node, model) + logger.debug("Successfully inferred node %s: %s", i, node.op_type) + except Exception as e: + error_msg = f"Failed to infer node {i} ({node.op_type}): {e}" + logger.exception(error_msg) + raise InferenceError(error_msg) from e + + logger.info("Model inference completed successfully") + + def _infer_node(self, node: ir.Node, model: ir.Model) -> None: + """Perform inference on a single node. + + Args: + node: The node to perform inference on. + model: The model containing the node (for context). + + Raises: + InferenceError: If no suitable inferrer is found or inference fails. + """ + # Find suitable inferrer + inferrer = self._find_inferrer(node, model) + if inferrer is None: + raise InferenceError( + f"No inferrer found for op_type '{node.op_type}' domain '{node.domain}'" + ) + + # Perform inference + result = inferrer.infer(node) + + if result.status == _common.InferenceStatus.INVALID_NODE: + # TODO: Print the node information + raise InferenceError(f"Invalid node: {result.msg}") + + if result.status == _common.InferenceStatus.MISSING_INFO: + logger.warning("Missing info for node %s: %s", node.op_type, result.msg) + # Continue with partial inference or skip + if result.values is None: + return # Skip this node + + if result.status == _common.InferenceStatus.PARTIAL: + logger.info("Partial inference for node %s: %s", node.op_type, result.msg) + # Continue with partial results + + if result.values is None: + raise InferenceError("Inference returned no values") + + # Apply reconciliation policy + self._reconcile_outputs(node, result.values) + + def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer | None: + """Find a suitable inferrer for the given node. + + Args: + node: The node to find an inferrer for. + model: The model containing the node. + + Returns: + The best matching inferrer, or None if no suitable inferrer is found. + """ + key = (node.domain, node.op_type, node.overload) + inferrers = self._inferrer_registry.get(key, []) + + if not inferrers: + return None + + # Get model opset version for this domain + if node.version is not None: + opset_version = node.version + elif node.graph is not None and node.domain in node.graph.opset_imports: + opset_version = node.graph.opset_imports[node.domain] + else: + # Fallback to model-level opset import + if node.domain not in model.opset_imports: + raise InferenceError( + f"No opset import found for domain '{node.domain}' in model" + ) + opset_version = model.opset_imports[node.domain] + + # Find inferrers that support this opset version + suitable_inferrers = [ + inferrer for inferrer in inferrers if opset_version in inferrer.opsets + ] + + if not suitable_inferrers: + logger.warning( + "No inferrer supports opset %s for %s (domain: %s)", + opset_version, + node.op_type, + node.domain, + ) + return None + + # Return the first suitable inferrer (could be enhanced with priority logic) + return suitable_inferrers[0] + + def _reconcile_outputs(self, node: ir.Node, inferred_values: Sequence[ir.Value]) -> None: + """Reconcile inferred output values with existing node outputs. + + Args: + node: The node whose outputs to reconcile. + inferred_values: The inferred output values. + + Raises: + InferenceError: If reconciliation fails under strict policy. 
+        """
+        if len(inferred_values) != len(node.outputs):
+            raise InferenceError(
+                f"Inference returned {len(inferred_values)} values but node has "
+                f"{len(node.outputs)} outputs"
+            )
+
+        for i, (existing_output, inferred_value) in enumerate(
+            zip(node.outputs, inferred_values)
+        ):
+            # Node outputs are always Value objects in onnx_ir, so reconcile by
+            # updating shape/type in place. Replacing the Value object would
+            # silently disconnect any existing users of the output.
+            if self.reconciliation_policy == ReconciliationPolicy.OVERWRITE:
+                existing_output.shape = inferred_value.shape
+                existing_output.type = inferred_value.type
+
+            elif self.reconciliation_policy == ReconciliationPolicy.IGNORE:
+                # Take the inferred info only if the existing output has none
+                if existing_output.shape is None and existing_output.type is None:
+                    existing_output.shape = inferred_value.shape
+                    existing_output.type = inferred_value.type
+                # Otherwise keep existing
+
+            elif self.reconciliation_policy == ReconciliationPolicy.RECONCILE:
+                reconciled_output = self._reconcile_value(existing_output, inferred_value)
+                existing_output.shape = reconciled_output.shape
+                existing_output.type = reconciled_output.type
+
+            elif self.reconciliation_policy == ReconciliationPolicy.STRICT:
+                if not self._values_compatible(existing_output, inferred_value):
+                    raise InferenceError(
+                        f"Output {i} mismatch: existing {existing_output} vs "
+                        f"inferred {inferred_value}"
+                    )
+                # Keep existing in strict mode if compatible
+
+    def _reconcile_value(self, existing: ir.Value, inferred: ir.Value) -> ir.Value:
+        """Reconcile an existing value with an inferred value.
+
+        Args:
+            existing: The existing value.
+            inferred: The inferred value.
+
+        Returns:
+            The reconciled value.
+        """
+        # Start with existing value
+        result_shape = existing.shape
+        result_type = existing.type
+
+        # Use inferred shape if existing is None or less specific
+        if inferred.shape is not None:
+            if result_shape is None:
+                result_shape = inferred.shape
+            else:
+                # Try to merge shapes (prefer more specific)
+                result_shape = self._reconcile_shapes(result_shape, inferred.shape)
+
+        # Use inferred type if existing is None
+        if inferred.type is not None and result_type is None:
+            result_type = inferred.type
+
+        return ir.Value(shape=result_shape, type=result_type)
+
+    def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape:
+        """Reconcile two shapes by preferring more specific dimensions.
+
+        Args:
+            shape1: First shape.
+            shape2: Second shape.
+
+        Returns:
+            The reconciled shape.
+        """
+        if len(shape1) != len(shape2):
+            logger.warning(
+                "Shape rank mismatch: %s vs %s. Using first shape.", len(shape1), len(shape2)
+            )
+            return shape1
+
+        reconciled_dims = []
+        for dim1, dim2 in zip(shape1.dims, shape2.dims):
+            # Prefer concrete dimensions over None/symbolic
+            if isinstance(dim1, int) and dim1 > 0:
+                reconciled_dims.append(dim1)
+            elif isinstance(dim2, int) and dim2 > 0:
+                reconciled_dims.append(dim2)
+            elif dim1 is not None:
+                reconciled_dims.append(dim1)
+            elif dim2 is not None:
+                reconciled_dims.append(dim2)
+            else:
+                reconciled_dims.append(None)
+
+        return ir.Shape(reconciled_dims)
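+
+    # Example of _reconcile_shapes (a sketch): merging ir.Shape(("N", None, 4))
+    # with ir.Shape((None, 3, 4)) yields ir.Shape(("N", 3, 4)); concrete dims
+    # win over unknown ones, and named symbolic dims win over None.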
+
+    def _values_compatible(self, value1: ir.Value, value2: ir.Value) -> bool:
+        """Check if two values are compatible (for strict mode).
+
+        Args:
+            value1: First value.
+            value2: Second value.
+
+        Returns:
+            True if the values are compatible.
+        """
+        # Check shape compatibility
+        if value1.shape is not None and value2.shape is not None:
+            if not self._shapes_compatible(value1.shape, value2.shape):
+                return False
+
+        # Check type compatibility
+        if value1.type is not None and value2.type is not None:
+            if value1.type != value2.type:
+                return False
+
+        return True
+
+    def _shapes_compatible(self, shape1: ir.Shape, shape2: ir.Shape) -> bool:
+        """Check if two shapes are compatible.
+
+        Args:
+            shape1: First shape.
+            shape2: Second shape.
+
+        Returns:
+            True if the shapes are compatible.
+        """
+        if len(shape1) != len(shape2):
+            return False
+
+        for dim1, dim2 in zip(shape1.dims, shape2.dims):
+            # None/symbolic dimensions are compatible with anything
+            if dim1 is None or dim2 is None:
+                continue
+
+            # Both concrete - must match
+            if isinstance(dim1, int) and isinstance(dim2, int):
+                if dim1 != dim2:
+                    return False
+
+            # Symbolic dimensions - for now assume compatible
+            # Could be enhanced with symbolic expression comparison
+
+        return True
+
+    def get_inferrer_info(self) -> dict[str, int]:
+        """Get information about registered inferrers.
+
+        Returns:
+            Dictionary mapping operation types to inferrer counts.
+        """
+        info = {}
+        for (domain, op_type, _overload), inferrers in self._inferrer_registry.items():
+            key = f"{op_type}:{domain}" if domain else op_type
+            info[key] = len(inferrers)
+        return info
diff --git a/src/onnx_ir/_shape_type_inference/factory.py b/src/onnx_ir/_shape_type_inference/factory.py
new file mode 100644
index 0000000..86c44bf
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/factory.py
@@ -0,0 +1,217 @@
+"""Factory functions for creating inference engines with standard inferrers."""
+
+from __future__ import annotations
+
+from onnx_ir._shape_type_inference._engine import ReconciliationPolicy, SymbolicInferenceEngine
+from onnx_ir._shape_type_inference.ops.concat import ConcatInferrer
+from onnx_ir._shape_type_inference.ops.constant import ConstantInferrer
+from onnx_ir._shape_type_inference.ops.matmul import MatMulInferrer
+from onnx_ir._shape_type_inference.ops.reshape import ReshapeInferrer
+from onnx_ir._shape_type_inference.ops.squeeze import Squeeze12Inferrer, Squeeze13Inferrer
+from onnx_ir._shape_type_inference.ops.standard_ops import BinaryInferrer, ElementwiseInferrer
+from onnx_ir._shape_type_inference.ops.transpose import TransposeInferrer
+from onnx_ir._shape_type_inference.ops.unsqueeze import (
+    Unsqueeze12Inferrer,
+    Unsqueeze13Inferrer,
+)
+
+
+def create_standard_inference_engine(
+    reconciliation_policy: ReconciliationPolicy = ReconciliationPolicy.RECONCILE,
+) -> SymbolicInferenceEngine:
+    """Create a SymbolicInferenceEngine with all standard operation inferrers.
+
+    Args:
+        reconciliation_policy: Policy for handling conflicts between inferred and existing values.
+
+    Returns:
+        A configured SymbolicInferenceEngine.
+ """ + inferrers = [] + + # Core tensor operations + inferrers.extend( + [ + ConstantInferrer(), + ReshapeInferrer(), + TransposeInferrer(), + # Squeeze/Unsqueeze with opset versions + Squeeze12Inferrer(), + Squeeze13Inferrer(), + Unsqueeze12Inferrer(), + Unsqueeze13Inferrer(), + ] + ) + + # Tensor manipulation + inferrers.extend( + [ + # GatherInferrer(), + # GatherElementsInferrer(), + # GatherNDInferrer(), + # ScatterElementsInferrer(), + # ExpandInferrer(), + # SliceInferrer(), + # SplitInferrer(), + ConcatInferrer(), + # PadInferrer(), + # TileInferrer(), + # WhereInferrer(), + # OneHotInferrer(), + # CompressInferrer(), + ] + ) + + # Mathematical operations + inferrers.extend( + [ + MatMulInferrer(), + # EinsumInferrer(), + # ReduceSumInferrer(), + # ReduceProdInferrer(), + ] + ) + + # Generation operations + inferrers.extend( + [ + # RangeInferrer(), + # ConstantOfShapeInferrer(), + # NonZeroInferrer(), + ] + ) + + # Pooling and convolution + inferrers.extend( + [ + # ConvInferrer(), + # AveragePoolInferrer(), + # MaxPoolInferrer(), + # BatchNormalizationInferrer(), + ] + ) + + # Sequence operations + inferrers.extend( + [ + # ConcatFromSequenceInferrer(), + # SplitToSequenceInferrer(), + # SequenceAtInferrer(), + # SequenceInsertInferrer(), + ] + ) + + # Control flow + inferrers.extend( + [ + # IfInferrer(), + # LoopInferrer(), + # ScanInferrer(), + ] + ) + + # ML-specific operations + inferrers.extend( + [ + # TopKInferrer(), + # NonMaxSuppressionInferrer(), + # SoftmaxCrossEntropyLossInferrer(), + # GroupNormInferrer(), + # GeluInferrer(), + ] + ) + + # Utility operations + inferrers.extend( + [ + # ArrayFeatureExtractorInferrer(), + # CategoryMapperInferrer(), + # ZipMapInferrer(), + # CumSumInferrer(), + # ResizeInferrer(), + ] + ) + + # Elementwise operations (covers many unary ops) + elementwise_ops = [ + "Abs", + "Acos", + "Acosh", + "Asin", + "Asinh", + "Atan", + "Atanh", + "Ceil", + "Cos", + "Cosh", + "Erf", + "Exp", + "Floor", + "Log", + "Neg", + "Reciprocal", + "Relu", + "Round", + "Sigmoid", + "Sign", + "Sin", + "Sinh", + "Sqrt", + "Tan", + "Tanh", + "Identity", + "IsInf", + "IsNaN", + ] + for op_type in elementwise_ops: + inferrers.append(ElementwiseInferrer(op_type)) + + # Binary operations (covers broadcasting ops) + binary_ops = [ + "Add", + "Sub", + "Mul", + "Div", + "Pow", + "Max", + "Min", + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + "And", + "Or", + "Xor", + ] + for op_type in binary_ops: + inferrers.append(BinaryInferrer(op_type)) + + return SymbolicInferenceEngine(inferrers, reconciliation_policy) + + +def create_minimal_inference_engine( + reconciliation_policy: ReconciliationPolicy = ReconciliationPolicy.RECONCILE, +) -> SymbolicInferenceEngine: + """Create a minimal SymbolicInferenceEngine with only essential inferrers. + + Args: + reconciliation_policy: Policy for handling conflicts between inferred and existing values. + + Returns: + A minimal SymbolicInferenceEngine. 
+    """
+    inferrers = [
+        # Core essentials
+        ConstantInferrer(),
+        ReshapeInferrer(),
+        TransposeInferrer(),
+        MatMulInferrer(),
+        ConcatInferrer(),
+        # Basic elementwise and binary
+        ElementwiseInferrer("Identity"),
+        BinaryInferrer("Add"),
+        BinaryInferrer("Mul"),
+    ]
+
+    return SymbolicInferenceEngine(inferrers, reconciliation_policy)
diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py
new file mode 100644
index 0000000..015ab8e
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/concat.py
@@ -0,0 +1,94 @@
+"""Concat operation inferrer for ONNX IR nodes."""
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+
+class ConcatInferrer(_common.NodeInferrer):
+    """Inferrer for Concat operations."""
+
+    def __init__(self) -> None:
+        """Initialize the Concat inferrer."""
+        super().__init__("Concat", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
+
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Concat operations."""
+        if len(node.inputs) < 1:
+            return _common.InferenceResult(
+                status="invalid_node", msg="Concat operation must have at least one input."
+            )
+        if any(inp is None for inp in node.inputs):
+            return _common.InferenceResult(
+                status="missing_info", msg="Concat operation inputs cannot be None."
+            )
+        if len(node.outputs) != 1:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Concat operation must have exactly one output, got {len(node.outputs)}.",
+            )
+
+        # Get axis attribute
+        axis = node.attributes.get_int("axis")
+        if axis is None:
+            return _common.InferenceResult(
+                status="invalid_node", msg="Concat operation requires axis attribute."
+            )
+
+        # Get first input shape as base
+        first_shape = node.inputs[0].shape
+        if first_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Concat input shapes cannot be None."
+            )
+        first_type = node.inputs[0].type
+
+        rank = len(first_shape)
+        if rank == 0:
+            return _common.InferenceResult(
+                status="invalid_node", msg="Concat inputs cannot be scalars."
+            )
+
+        # Handle negative axis
+        if axis < 0:
+            axis += rank
+
+        if axis < 0 or axis >= rank:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Concat axis {axis} is out of bounds for rank {rank}.",
+            )
+
+        # Check that all inputs have compatible shapes
+        output_dims = list(first_shape)
+        concat_dim_size = _common.get_expr(first_shape, axis)
+
+        for i, inp in enumerate(node.inputs[1:], 1):
+            if inp is None:
+                return _common.InferenceResult(
+                    status="missing_info", msg=f"Input {i} cannot be None."
+                )
+            if inp.shape is None:
+                return _common.InferenceResult(
+                    status="missing_info", msg=f"Input {i} shape cannot be None."
+                )
+
+            input_shape = inp.shape
+            if len(input_shape) != rank:
+                return _common.InferenceResult(
+                    status="invalid_node",
+                    msg=f"All inputs must have the same rank. Input {i} has rank {len(input_shape)}, expected {rank}.",
+                )
+
+            # TODO(justinchuby): Check non-concat dimensions are compatible
+            concat_dim_size = concat_dim_size + _common.get_expr(input_shape, axis)
+            if inp.type != first_type:
+                return _common.InferenceResult(
+                    status="invalid_node",
+                    msg=f"Input {i} type {inp.type} does not match first input type {first_type}.",
+                )
+
+        # Set the concat dimension in output shape
+        output_dims[axis] = concat_dim_size
+        return _common.InferenceResult(
+            values=(ir.Value(shape=ir.Shape(output_dims), type=first_type),)
+        )
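+
+
+# Example (a sketch): concatenating inputs of shapes ("N", 3) and ("N", 5)
+# along axis=1 yields ("N", 8); the concat-axis dims are summed via sympy.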
diff --git a/src/onnx_ir/_shape_type_inference/ops/constant.py b/src/onnx_ir/_shape_type_inference/ops/constant.py
new file mode 100644
index 0000000..d9338ec
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/constant.py
@@ -0,0 +1,36 @@
+"""Constant operation inferrer for ONNX IR nodes."""
+
+from __future__ import annotations
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+
+class ConstantInferrer(_common.NodeInferrer):
+    """Inferrer for Constant operations."""
+
+    def __init__(self) -> None:
+        """Initialize the Constant inferrer."""
+        super().__init__(
+            "Constant", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
+
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Constant operations."""
+        # Constant has no inputs; the tensor is recovered from the node's
+        # value attributes via its output value.
+        tensor = ir.convenience.get_const_tensor(node.outputs[0])
+        if tensor is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Constant tensor cannot be obtained."
+            )
+
+        # Create shape from the tensor dimensions
+        output_shape = ir.Shape(tensor.shape)
+
+        # Get the data type from the tensor
+        output_type = ir.TensorType(tensor.dtype)
+
+        return _common.InferenceResult(
+            values=(ir.Value(shape=output_shape, type=output_type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py
new file mode 100644
index 0000000..45d31ea
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py
@@ -0,0 +1,72 @@
+"""MatMul operation inferrer for ONNX IR nodes."""
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+from onnx_ir._shape_type_inference.ops.standard_ops import broadcast_shapes_bidirectional
+
+
+class MatMulInferrer(_common.NodeInferrer):
+    """Inferrer for MatMul operations."""
+
+    def __init__(self) -> None:
+        """Initialize the MatMul inferrer."""
+        super().__init__("MatMul", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
+
+    @_common.requires_non_none_inputs(2)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for MatMul operations."""
+        assert node.inputs[0] is not None and node.inputs[1] is not None
+
+        lhs_shape = node.inputs[0].shape
+        rhs_shape = node.inputs[1].shape
+        if lhs_shape is None or rhs_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="MatMul input shapes cannot be None."
+            )
+
+        lhs_rank = len(lhs_shape)
+        rhs_rank = len(rhs_shape)
+
+        if lhs_rank == 0 or rhs_rank == 0:
+            return _common.InferenceResult(
+                status="invalid_node", msg="MatMul inputs cannot be scalars."
+            )
+
+        # Compute output shape based on matrix multiplication rules
+        if lhs_rank == 1 and rhs_rank == 1:
+            # Vector dot product: (n,) x (n,) -> scalar
+            output_shape = ir.Shape([])
+        elif lhs_rank == 1:
+            # Vector-matrix: (n,) x (..., n, k) -> (..., k)
+            output_dims = [*rhs_shape[:-2], rhs_shape[-1]]
+            output_shape = ir.Shape(output_dims)
+        elif rhs_rank == 1:
+            # Matrix-vector: (..., m, n) x (n,) -> (..., m)
+            output_shape = ir.Shape(lhs_shape[:-1])
+        else:
+            # Matrix-matrix: (..., m, n) x (..., n, k) -> (..., m, k)
+            # Broadcast batch dimensions
+            lhs_batch = lhs_shape[:-2]
+            rhs_batch = rhs_shape[:-2]
+            if lhs_batch and rhs_batch:
+                # TODO(justinchuby): Ensure this is correct
+                batch_shape = broadcast_shapes_bidirectional(
+                    ir.Shape(lhs_batch), ir.Shape(rhs_batch)
+                )
+                output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]
+                output_shape = ir.Shape(output_dims)
+            elif lhs_batch:
+                output_dims = [*lhs_batch, lhs_shape[-2], rhs_shape[-1]]
+                output_shape = ir.Shape(output_dims)
+            elif rhs_batch:
+                output_dims = [*rhs_batch, lhs_shape[-2], rhs_shape[-1]]
+                output_shape = ir.Shape(output_dims)
+            else:
+                output_dims = [lhs_shape[-2], rhs_shape[-1]]
+                output_shape = ir.Shape(output_dims)
+
+        output_type = node.inputs[0].type
+        return _common.InferenceResult(
+            values=(ir.Value(shape=output_shape, type=output_type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py
new file mode 100644
index 0000000..80659b5
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py
@@ -0,0 +1,92 @@
+"""Reshape operation inferrer for ONNX IR nodes."""
+
+import sympy
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
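+
+
+# Example (a sketch) of the Reshape rules implemented below: for an input of
+# shape ("N", 12, 64) and a target shape of [0, -1], the 0 copies dim 0 ("N")
+# and the -1 resolves to (N * 12 * 64) / N = 768.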
+
+
+class ReshapeInferrer(_common.NodeInferrer):
+    """Inferrer for Reshape operations."""
+
+    def __init__(self) -> None:
+        """Initialize the Reshape inferrer."""
+        super().__init__(
+            "Reshape", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
+
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Reshape operations."""
+        if len(node.inputs) != 2:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}.",
+            )
+        if node.inputs[0] is None or node.inputs[1] is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Reshape operation inputs cannot be None."
+            )
+        if len(node.outputs) != 1:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Reshape operation must have exactly one output, got {len(node.outputs)}.",
+            )
+
+        input_shape = node.inputs[0].shape
+        shape_input = node.inputs[1]
+
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Reshape input shape cannot be None."
+            )
+
+        # Try to get the shape values from the second input
+        # For symbolic inference, we may not have concrete values
+        shape = ir.convenience.get_const_tensor(shape_input)
+        if shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Reshape shape input is not known."
+            )
+
+        shape_values = shape.numpy().tolist()
+
+        # Calculate total elements in input
+        total_elements = sympy.Integer(1)
+        for dim in range(input_shape.rank()):
+            total_elements *= _common.get_expr(input_shape, dim)
+
+        # Process shape values
+        # TODO(justinchuby): Honor the allowzero attribute (opset 14+), which
+        # changes the meaning of 0 entries in the target shape.
+        output_dims = []
+        deferred_dim_idx = -1
+        non_deferred_size = sympy.Integer(1)
+
+        for i, dim_value in enumerate(shape_values):
+            if dim_value == -1:
+                if deferred_dim_idx != -1:
+                    return _common.InferenceResult(
+                        status="invalid_node", msg="Reshape can have at most one -1 dimension."
+                    )
+                deferred_dim_idx = i
+                output_dims.append(None)  # Placeholder
+            elif dim_value == 0:
+                # Copy from input shape
+                if i >= len(input_shape):
+                    return _common.InferenceResult(
+                        status="invalid_node",
+                        msg=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}.",
+                    )
+                dim_expr = _common.get_expr(input_shape, i)
+                output_dims.append(dim_expr)
+                non_deferred_size *= dim_expr
+            else:
+                output_dims.append(dim_value)
+                non_deferred_size *= sympy.Integer(dim_value)
+
+        # Calculate deferred dimension. For a valid Reshape the division is
+        # exact; floor division keeps the expression an integer for sympy.
+        if deferred_dim_idx != -1:
+            deferred_dim = total_elements // non_deferred_size
+            output_dims[deferred_dim_idx] = deferred_dim
+
+        # Create output shape
+        return _common.InferenceResult(
+            values=(ir.Value(shape=ir.Shape(output_dims), type=node.inputs[0].type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py
new file mode 100644
index 0000000..3494e56
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py
@@ -0,0 +1,133 @@
+"""Squeeze operation inferrer for ONNX IR nodes."""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_output_shape_no_axes(input_shape: ir.Shape) -> ir.Shape:
+    """Compute output shape when no axes are specified."""
+    output_dims = []
+    for dim in input_shape:
+        # For symbolic dimensions, we assume they are not 1
+        # Only squeeze literal 1s
+        if isinstance(dim, int):
+            if dim == 1:
+                continue  # Skip dimension of size 1
+            else:
+                output_dims.append(dim)
+        else:
+            logger.warning(
+                "Squeeze operation has symbolic dimension %s, assuming it is not 1.", dim
+            )
+            output_dims.append(dim)
+    return ir.Shape(output_dims)
+
+
+def _normalize_axes(axes: Sequence[int], rank: int) -> set[int]:
+    """Normalize axes to be within the valid range for the given rank."""
+    normalized_axes = set()
+    for axis in axes:
+        if axis < 0:
+            axis += rank
+        if axis < 0 or axis >= rank:
+            raise ValueError(f"Squeeze axis {axis} is out of bounds for rank {rank}.")
+        normalized_axes.add(axis)
+    return normalized_axes
+
+
+def _compute_output_shape_with_axes(input_shape: ir.Shape, axes: set[int]) -> ir.Shape:
+    """Compute output shape when axes are specified."""
+    output_dims = [dim for i, dim in enumerate(input_shape) if i not in axes]
+    return ir.Shape(output_dims)
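+
+
+# Example (a sketch): _compute_output_shape_with_axes(ir.Shape((1, "N", 1, 4)), {0, 2})
+# removes axes 0 and 2 and returns ir.Shape(("N", 4)).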
+
+
+class Squeeze12Inferrer(_common.NodeInferrer):
+    """Inferrer for Squeeze-12 and lower.
+
+    We assume that axes don't have duplicates.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the Squeeze inferrer."""
+        super().__init__("Squeeze", opsets=range(13))
+
+    @_common.requires_non_none_inputs(1)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Squeeze operations."""
+        input = node.inputs[0]
+        assert input is not None
+        input_shape = input.shape
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Squeeze input shape is not known."
+            )
+
+        rank = len(input_shape)
+
+        # Get axes to squeeze
+        axes = node.attributes.get_ints("axes")
+
+        if axes is None:
+            output_shape = _compute_output_shape_no_axes(input_shape)
+        else:
+            try:
+                axes = _normalize_axes(axes, rank)
+            except ValueError as e:
+                return _common.InferenceResult(status="invalid_node", msg=str(e))
+            output_shape = _compute_output_shape_with_axes(input_shape, axes)
+        return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),))
+
+
+class Squeeze13Inferrer(_common.NodeInferrer):
+    """Inferrer for Squeeze-13 and higher.
+
+    We assume that axes don't have duplicates.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the Squeeze inferrer."""
+        super().__init__(
+            "Squeeze", opsets=_common.inclusive_range(13, _common.MAX_SUPPORTED_OPSET)
+        )
+
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Squeeze operations."""
+        # In opset 13+, axes moved from an attribute to an optional second input.
+        if len(node.inputs) not in (1, 2):
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Squeeze must have one or two inputs, got {len(node.inputs)}.",
+            )
+        data = node.inputs[0]
+        if data is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Squeeze input 0 cannot be None."
+            )
+
+        input_shape = data.shape
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Squeeze input shape is not known."
+            )
+
+        rank = len(input_shape)
+
+        axes_input = node.inputs[1] if len(node.inputs) > 1 else None
+        if axes_input is None:
+            # No axes provided: squeeze all dimensions that are statically 1
+            output_shape = _compute_output_shape_no_axes(input_shape)
+            return _common.InferenceResult(
+                values=(ir.Value(shape=output_shape, type=data.type),)
+            )
+
+        axes_tensor = ir.convenience.get_const_tensor(axes_input)
+        if axes_tensor is not None:
+            try:
+                axes = _normalize_axes(axes_tensor.numpy().tolist(), rank)
+            except ValueError as e:
+                return _common.InferenceResult(status="invalid_node", msg=str(e))
+            output_shape = _compute_output_shape_with_axes(input_shape, axes)
+        else:
+            axes_shape = axes_input.shape
+            if axes_shape is None or axes_shape.is_dynamic():
+                return _common.InferenceResult(
+                    status="missing_info",
+                    msg="Squeeze axes input shape is not known or is dynamic",
+                )
+            removed_axes_count = axes_shape[0]
+            assert isinstance(removed_axes_count, int)
+            output_shape = ir.Shape([None] * (rank - removed_axes_count))
+        return _common.InferenceResult(
+            values=(ir.Value(shape=output_shape, type=data.type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py
new file mode 100644
index 0000000..90b4201
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py
@@ -0,0 +1,108 @@
+"""Standard Inferrers for ONNX IR nodes."""
+
+from __future__ import annotations
+
+from collections.abc import Collection
+
+import sympy
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+
+class ElementwiseInferrer(_common.NodeInferrer):
+    """Base class for elementwise operation inferrers."""
+
+    def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None:
+        """Initialize the elementwise inferrer with the operation type."""
+        if opsets is None:
+            opsets = _common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        super().__init__(op_type, opsets=opsets)
+
+    @_common.requires_non_none_inputs(1)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for elementwise operations."""
+        assert node.inputs[0] is not None
+        return _common.InferenceResult(
+            (ir.Value(shape=node.inputs[0].shape, type=node.inputs[0].type),)
+        )
+
+
+def broadcast_shapes_bidirectional(shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape:
+    """Broadcast two shapes bidirectionally.
+
+    Args:
+        shape1: The first shape to broadcast.
+        shape2: The second shape to broadcast.
+
+    Returns:
+        A new shape that is the result of broadcasting both shapes.
+    """
+    rank1 = len(shape1)
+    rank2 = len(shape2)
+    new_rank = max(rank1, rank2)
+    new_dims = []
+
+    for i in range(new_rank):
+        dim1_idx = rank1 - 1 - i
+        dim2_idx = rank2 - 1 - i
+
+        # Get expressions for dimensions
+        dim1_expr = _common.get_expr(shape1, dim1_idx) if i < rank1 else sympy.Integer(1)
+        dim2_expr = _common.get_expr(shape2, dim2_idx) if i < rank2 else sympy.Integer(1)
+
+        # Broadcasting rules
+        if dim1_expr == 1:
+            new_dim_expr = dim2_expr
+        elif dim2_expr == 1:
+            new_dim_expr = dim1_expr
+        elif dim1_expr == dim2_expr:
+            new_dim_expr = dim1_expr
+        else:
+            # The dimensions cannot be proven equal. Whenever broadcasting is
+            # valid at runtime (one side is 1 or both sides are equal), the
+            # result equals Max(dim1, dim2), so use that as the symbolic answer.
+            new_dim_expr = sympy.Max(dim1_expr, dim2_expr)
+
+        # Add to the front to maintain right-to-left processing order
+        new_dims.insert(0, new_dim_expr)
+
+    # Create new shape directly
+    return ir.Shape(new_dims)
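+
+
+# Example (a sketch): broadcast_shapes_bidirectional(ir.Shape(("N", 1, 4)), ir.Shape((3, 1)))
+# aligns the shapes from the right and returns ir.Shape(("N", 3, 4)).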
+
+
+class BinaryInferrer(_common.NodeInferrer):
+    """Base class for binary operation inferrers."""
+
+    def __init__(self, op_type: str) -> None:
+        """Initialize the binary inferrer with the operation type."""
+        super().__init__(op_type, opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
+
+    @_common.requires_non_none_inputs(2)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for binary operations."""
+        assert node.inputs[0] is not None
+        assert node.inputs[1] is not None
+        first_type = node.inputs[0].type
+        second_type = node.inputs[1].type
+        # TODO(justinchuby): Comparison ops (Equal, Greater, Less, ...) produce
+        # BOOL outputs; this template currently propagates the input type.
+        if first_type is not None and second_type is not None and first_type != second_type:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Input types do not match: {first_type} vs {second_type}.",
+            )
+
+        # Broadcast the input shapes
+        first_shape = node.inputs[0].shape
+        second_shape = node.inputs[1].shape
+        if first_shape is None or second_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Input shapes cannot be None."
+            )
+
+        output_shape = broadcast_shapes_bidirectional(first_shape, second_shape)
+        output_type = first_type if first_type is not None else second_type
+
+        return _common.InferenceResult(
+            values=(ir.Value(shape=output_shape, type=output_type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py
new file mode 100644
index 0000000..7f0a93a
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py
@@ -0,0 +1,67 @@
+"""Transpose operation inferrer for ONNX IR nodes."""
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+
+class TransposeInferrer(_common.NodeInferrer):
+    """Inferrer for Transpose operations."""
+
+    def __init__(self) -> None:
+        """Initialize the Transpose inferrer."""
+        super().__init__(
+            "Transpose", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
+
+    @_common.requires_non_none_inputs(1)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Transpose operations."""
+        assert node.inputs[0] is not None
+        input_shape = node.inputs[0].shape
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Transpose input shape cannot be None."
+            )
+
+        rank = len(input_shape)
+
+        # Get permutation from attributes
+        perm = node.attributes.get_ints("perm")
+
+        # Default permutation is reversed order
+        if perm is None:
+            perm = list(reversed(range(rank)))
+
+        # Validate permutation
+        if len(perm) != rank:
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Permutation length {len(perm)} does not match input rank {rank}.",
+            )
+
+        if sorted(perm) != list(range(rank)):
+            return _common.InferenceResult(
+                status="invalid_node",
+                msg=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}].",
+            )
+
+        # Apply the (already validated) permutation to create the output shape
+        output_dims = [input_shape[axis] for axis in perm]
+
+        return _common.InferenceResult(
+            values=(ir.Value(shape=ir.Shape(output_dims), type=node.inputs[0].type),)
+        )
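+
+
+# Example (a sketch): transposing shape ("N", 3, 5) with perm=[2, 0, 1]
+# yields shape (5, "N", 3).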
diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py
new file mode 100644
index 0000000..4d8aa7e
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py
@@ -0,0 +1,154 @@
+"""Unsqueeze operation inferrer for ONNX IR nodes."""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_axes(axes: Sequence[int], output_rank: int) -> set[int]:
+    """Normalize axes to be within the valid range for the given output rank."""
+    normalized_axes = set()
+    for axis in axes:
+        if axis < 0:
+            axis += output_rank
+        if axis < 0 or axis >= output_rank:
+            raise ValueError(
+                f"Unsqueeze axis {axis} is out of bounds for output rank {output_rank}."
+            )
+        normalized_axes.add(axis)
+
+    # Check for duplicate axes
+    if len(normalized_axes) != len(axes):
+        raise ValueError("Unsqueeze axes must be unique.")
+
+    return normalized_axes
+
+
+def _compute_output_shape(input_shape: ir.Shape, axes: set[int]) -> ir.Shape:
+    """Compute output shape by inserting 1s at specified axes."""
+    input_rank = len(input_shape)
+    output_rank = input_rank + len(axes)
+
+    output_dims = []
+    input_axis = 0
+
+    for output_axis in range(output_rank):
+        if output_axis in axes:
+            # Insert dimension of size 1
+            output_dims.append(1)
+        else:
+            # Copy dimension from input
+            output_dims.append(input_shape[input_axis])
+            input_axis += 1
+
+    return ir.Shape(output_dims)
+
+
+class Unsqueeze12Inferrer(_common.NodeInferrer):
+    """Inferrer for Unsqueeze-12 and lower.
+
+    In these versions, axes are provided as an attribute.
+    We assume that axes don't have duplicates.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the Unsqueeze inferrer."""
+        super().__init__("Unsqueeze", opsets=range(13))
+
+    @_common.requires_non_none_inputs(1)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Unsqueeze operations."""
+        input = node.inputs[0]
+        assert input is not None
+        input_shape = input.shape
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Unsqueeze input shape is not known."
+            )
+
+        input_rank = len(input_shape)
+
+        # Get axes to unsqueeze from attributes
+        axes = node.attributes.get_ints("axes")
+        if axes is None:
+            return _common.InferenceResult(
+                status="invalid_node", msg="Unsqueeze operation requires axes attribute."
+            )
+
+        output_rank = input_rank + len(axes)
+
+        try:
+            normalized_axes = _normalize_axes(axes, output_rank)
+        except ValueError as e:
+            return _common.InferenceResult(status="invalid_node", msg=str(e))
+
+        output_shape = _compute_output_shape(input_shape, normalized_axes)
+        return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),))
+
+
+class Unsqueeze13Inferrer(_common.NodeInferrer):
+    """Inferrer for Unsqueeze-13 and higher.
+
+    In these versions, axes are provided as a second input tensor.
+    We assume that axes don't have duplicates.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the Unsqueeze inferrer."""
+        super().__init__(
+            "Unsqueeze", opsets=_common.inclusive_range(13, _common.MAX_SUPPORTED_OPSET)
+        )
+
+    @_common.requires_non_none_inputs(2)
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Unsqueeze operations."""
+        assert node.inputs[0] is not None
+        assert node.inputs[1] is not None
+
+        input_shape = node.inputs[0].shape
+        if input_shape is None:
+            return _common.InferenceResult(
+                status="missing_info", msg="Unsqueeze input shape is not known."
+ ) + + input_rank = len(input_shape) + + axes_tensor = ir.convenience.get_const_tensor(node.inputs[1]) + if axes_tensor is not None: + axes = axes_tensor.numpy().tolist() + if not isinstance(axes, list): + axes = [axes] + + output_rank = input_rank + len(axes) + + try: + normalized_axes = _normalize_axes(axes, output_rank) + except ValueError as e: + return _common.InferenceResult(status="invalid_node", msg=str(e)) + + output_shape = _compute_output_shape(input_shape, normalized_axes) + else: + # Handle case where axes tensor is not constant + axes_shape = node.inputs[1].shape + if axes_shape is None or axes_shape.is_dynamic(): + return _common.InferenceResult( + status="missing_info", + msg="Unsqueeze axes input shape is not known or is dynamic", + ) + + # We know the number of axes to insert but not their positions + added_axes_count = axes_shape[0] + assert isinstance(added_axes_count, int) + output_rank = input_rank + added_axes_count + # Create output shape with unknown dimensions + output_shape = ir.Shape([None] * output_rank) + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) + )