[WIP] Create symbolic type/shape inference logic #117

Draft: wants to merge 31 commits into base `main`.
Commits (31, all by justinchuby):

- 1adca71 Scaffold (Jun 29, 2025)
- f48537b Add support for expr in SymbolicDim (Jun 29, 2025)
- 7464af1 wip (Jun 29, 2025)
- 27ae0ca wip (Jun 29, 2025)
- 1afe64a Create NodeInferencer (Jun 29, 2025)
- 78ad6e0 inference_common (Jun 29, 2025)
- 5aa2df7 Update shapes (Jun 29, 2025)
- dbc3593 update (Jun 29, 2025)
- b9f0528 Claude - add sympy import (Jun 30, 2025)
- c9a35b7 Claude and lint (Jun 30, 2025)
- 65e3dd2 concat (Jun 30, 2025)
- 7960770 Update _maybe_convert_to_symbolic_dim (Jun 30, 2025)
- a7704c5 reshape (Jun 30, 2025)
- 922a597 Update the way dim is set (Jun 30, 2025)
- 9183848 Simplify (Jun 30, 2025)
- 9300aba Update (Jun 30, 2025)
- 8747a93 Handle unknown dims (Jun 30, 2025)
- 92049c4 Simplify (Jun 30, 2025)
- 720845e Create inclusive range (Jun 30, 2025)
- bae78ab WIP inference engine (Jun 30, 2025)
- a77f487 Create readme (Jun 30, 2025)
- 6686457 Result (Jun 30, 2025)
- 3207e84 Summary of Complete Refactoring (Jun 30, 2025)
- a572145 lint (Jun 30, 2025)
- 11f8958 Removes unused shape inference code (Jun 30, 2025)
- f3c70da Summary of Shape Simplifications (Jun 30, 2025)
- 4b6d80d Create factory (Jun 30, 2025)
- e03733b Use Enum (Jun 30, 2025)
- 5a34891 Update logging calls (Jun 30, 2025)
- ab09107 Working on engine (Jun 30, 2025)
- 9256233 todo (Jun 30, 2025)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: Apache Software License",
]
-dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
+dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "sympy"]

[project.urls]
Homepage = "https://onnx.ai/ir-py"
37 changes: 31 additions & 6 deletions src/onnx_ir/_core.py
@@ -45,6 +45,8 @@

import ml_dtypes
import numpy as np
+import sympy
+import sympy.utilities.misc
from typing_extensions import TypeIs

import onnx_ir
@@ -1115,13 +1117,14 @@
    It is immutable and can be compared or hashed.
    """

-    __slots__ = ("_value",)
+    __slots__ = ("_expr", "_value")

-    def __init__(self, value: str | None) -> None:
+    def __init__(self, value: str | None, /, expr: sympy.Expr | None = None) -> None:
        """Initialize a symbolic dimension.

        Args:
            value: The value of the dimension. It should not be an int.
+            expr: An optional sympy expression representing the dimension.

        Raises:
            TypeError: If value is an int.
@@ -1132,6 +1135,7 @@
                "If you are creating a Shape, use int directly instead of SymbolicDim."
            )
        self._value = value
+        self._expr: sympy.Expr | None = expr

    def __eq__(self, other: object) -> bool:
        """Check equality with another SymbolicDim or string/None."""
@@ -1148,11 +1152,24 @@
        """The value of the symbolic dimension (string or None)."""
        return self._value

+    @property
+    def expr(self) -> sympy.Expr | None:
+        """The sympy expression representing the symbolic dimension."""
+        return self._expr
    def __str__(self) -> str:
-        return f"{self._value}"
+        if self._value is not None:
+            return str(self._value)
+        if self._expr is not None:
+            return str(self._expr)
+        return "?"

    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({self._value})"
+        if self._expr is not None:
+            expr_text = f", expr={self._expr!r}"
+        else:
+            expr_text = ""
+        return f"{self.__class__.__name__}({self._value}{expr_text})"


def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
@@ -1190,10 +1207,16 @@
        return SymbolicDim(dim)
    if _is_int_compatible(dim):
        return int(dim)
+    if isinstance(dim, sympy.Expr):
+        # If the dimension is a sympy expression, we create a SymbolicDim with it
+        expr = sympy.sympify(dim)
+        if expr.is_integer:
+            return sympy.utilities.misc.as_int(expr)
+        return SymbolicDim(str(expr), expr=expr)
    if isinstance(dim, SymbolicDim):
        return dim
    raise TypeError(
-        f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
+        f"Expected int, str, sympy.Expr, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
    )


@@ -1334,7 +1357,9 @@
    def __getitem__(self, index):
        return tuple(self._dims)[index]

-    def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
+    def __setitem__(
+        self, index: int, value: int | SymbolicDim | str | sympy.Expr | None
+    ) -> None:
        """Set the dimension at the index.

        Args:
262 changes: 262 additions & 0 deletions src/onnx_ir/_shape_type_inference/README.md
@@ -0,0 +1,262 @@
# Symbolic Shape and Type Inference

This module provides symbolic shape and type inference for ONNX IR models, enabling compile-time analysis of tensor shapes and types with support for symbolic dimensions.

## Overview

The inference engine performs forward propagation through ONNX models to determine output shapes and types based on input specifications. It supports symbolic dimensions using SymPy expressions, allowing for dynamic shape analysis.

## Key Components

### Core Classes

- **`SymbolicInferenceEngine`**: Main orchestrator that processes models and applies inference
- **`NodeInferrer`**: Base class for operation-specific inference logic
- **`InferenceResult`**: Container for inference results with status and optional message
- **`InferenceStatus`**: Enum for inference operation status (SUCCESS, PARTIAL, MISSING_INFO, INVALID_NODE)

### Reconciliation Policies

The engine supports different strategies for handling conflicts between inferred and existing values:

- **`OVERWRITE`**: Always use inferred values
- **`IGNORE`**: Keep existing values if they exist
- **`RECONCILE`**: Merge inferred and existing values intelligently
- **`STRICT`**: Fail if inferred values don't match existing ones
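
As a rough illustration, a per-dimension reconciliation step under these policies might look like this (a hypothetical `reconcile_dim` helper, not the engine's actual code):

```python
def reconcile_dim(existing, inferred, policy):
    """Resolve a single dimension according to a reconciliation policy (sketch)."""
    if policy == "OVERWRITE":
        return inferred
    if policy == "IGNORE":
        return existing if existing is not None else inferred
    if policy == "STRICT":
        if existing is not None and inferred is not None and existing != inferred:
            raise ValueError(f"Dimension mismatch: {existing} vs {inferred}")
        return existing if existing is not None else inferred
    # RECONCILE: keep the more specific value, preferring concrete ints
    if existing is None:
        return inferred
    if inferred is None:
        return existing
    return inferred if isinstance(inferred, int) else existing
```

The per-dimension rule would then be applied element-wise across the two shapes being reconciled.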

### Inference Status System

The `InferenceResult` uses a status-based approach for granular error handling:

- **`SUCCESS`**: Complete inference successful with full shape/type information
- **`PARTIAL`**: Partial information available (e.g., type only, rank only)
- **`MISSING_INFO`**: Missing required input information (shapes, types)
- **`INVALID_NODE`**: Node is invalid or malformed

```python
# Example usage in inferrers
def infer(self, node: ir.Node) -> InferenceResult:
    if node.inputs[0].shape is None:
        return InferenceResult(
            status="missing_info",
            msg="Input shape is required"
        )

    # Partial inference - only type available
    if can_infer_type_only():
        return InferenceResult(
            values=[ir.Value(type=inferred_type)],
            status="partial",
            msg="Shape unavailable, type only"
        )

    # Full inference (status defaults to "success")
    return InferenceResult(values=[full_value])
```

## Architecture

```text
SymbolicInferenceEngine
├── NodeInferrer Registry (by op_type + domain)
├── Opset Version Matching
└── Reconciliation Logic

NodeInferrer Implementations
├── ElementwiseInferrer (unary operations)
├── BinaryInferrer (broadcasting operations)
└── Specialized Inferrers (50+ operations)
```

## Inferrer Selection

The engine selects the appropriate inferrer using a two-stage process:

1. **Registry Lookup**: Inferrers are registered by `(op_type, domain)` key
2. **Opset Matching**: Among matching inferrers, select those supporting the model's opset version

```python
# Example: For a Squeeze node with opset 14
# - Multiple Squeeze inferrers may be registered
# - Engine selects Squeeze13Inferrer (supports opset 13-23)
# - Ignores Squeeze12Inferrer (supports opset 1-12)
```
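
A minimal sketch of this two-stage lookup (hypothetical names; the engine's real registry and matching logic may differ):

```python
class _Inferrer:
    """Stand-in for NodeInferrer metadata used by the registry (illustrative only)."""

    def __init__(self, op_type, opsets, domain=""):
        self.op_type = op_type
        self.opsets = set(opsets)
        self.domain = domain


_registry = {}


def register(inferrer):
    # Stage 1 key: (op_type, domain)
    _registry.setdefault((inferrer.op_type, inferrer.domain), []).append(inferrer)


def find_inferrer(op_type, domain, opset_version):
    # Stage 2: among registered candidates, pick one supporting the opset
    for candidate in _registry.get((op_type, domain), []):
        if opset_version in candidate.opsets:
            return candidate
    return None


register(_Inferrer("Squeeze", range(1, 13)))
register(_Inferrer("Squeeze", range(13, 24)))
```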

## Symbolic Dimensions

The system stores symbolic expressions in `ir.SymbolicDim` objects:

```python
class SymbolicDim:
    value: str | None  # String identifier (e.g., "N", "batch_size")
    expr: sympy.Expr | None  # SymPy expression for computed dimensions
```

Dimensions are accessed via `get_expr()`, which converts them to SymPy expressions:

- `SymbolicDim(value="N")` → `sympy.Symbol("N")`
- `SymbolicDim(expr=N*2)` → `N*2` (SymPy expression)
- Integer dimensions → `sympy.Integer(value)`
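
Assuming those conversion rules, a standalone helper in the spirit of `get_expr()` could be sketched as follows (hypothetical `dim_to_expr`; the real conversion lives on the IR classes):

```python
import sympy


def dim_to_expr(dim):
    """Convert an int, a symbol name, or a sympy expression to a sympy.Expr (sketch)."""
    if isinstance(dim, int):
        return sympy.Integer(dim)
    if isinstance(dim, str):
        return sympy.Symbol(dim)
    if isinstance(dim, sympy.Expr):
        return dim
    return None  # unknown dimension


N = sympy.Symbol("N")
# Computed dims stay symbolic and can participate in arithmetic/simplification
out_dim = dim_to_expr(N * 2) + dim_to_expr(2)
```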

## NodeInferrer Design Decisions

### Base Class Structure
The `NodeInferrer` abstract base class enforces a consistent interface:

```python
class NodeInferrer(abc.ABC):
    def __init__(self, op_type: str, opsets: Collection[int], domain: str = ""):
        # Store operation metadata for registry matching
        ...

    @abc.abstractmethod
    def infer(self, node: ir.Node) -> InferenceResult:
        # Operation-specific inference logic
        ...
```

### Design Rationale

1. **Single Responsibility**: Each inferrer handles exactly one operation type
2. **Opset Awareness**: Inferrers declare supported ONNX opset versions for compatibility
3. **Domain Support**: Enables custom domains beyond standard ONNX operators
4. **Validation Decorators**: `@requires_non_none_inputs(n)` and `@requires_outputs(n)` provide consistent input validation
5. **Failure Handling**: Return an `InferenceResult` carrying either inferred `values` or a non-success `status` (with `msg`) for graceful error handling

### Inheritance Patterns

- **ElementwiseInferrer**: Template for unary operations that preserve input shape/type
- **BinaryInferrer**: Template for binary operations with broadcasting logic
- **Specialized Inferrers**: Custom logic for complex operations (Conv, Reshape, etc.)
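
As an illustration of what `BinaryInferrer` must do, NumPy-style broadcasting over per-dimension values can be sketched as follows (simplified; the real implementation also handles symbolic expressions):

```python
def broadcast_shapes(shape_a, shape_b):
    """Right-aligned NumPy-style broadcasting over int/str dims (sketch)."""
    result = []
    # Walk dimensions from the trailing axis; missing dims act as 1
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        a = shape_a[-i] if i <= len(shape_a) else 1
        b = shape_b[-i] if i <= len(shape_b) else 1
        if a == 1:
            result.append(b)
        elif b == 1 or a == b:
            result.append(a)
        else:
            raise ValueError(f"Cannot broadcast {a} and {b}")
    return list(reversed(result))
```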

## Usage

### Basic Usage

```python
from onnx_ir._shape_type_inference.factory import create_standard_inference_engine
from onnx_ir._shape_type_inference import ReconciliationPolicy

# Create engine with all standard operations
engine = create_standard_inference_engine(ReconciliationPolicy.RECONCILE)

# Perform inference on a model
engine.infer_model(model)
```

### Custom Engine

```python
from onnx_ir._shape_type_inference import (
    ReconciliationPolicy,
    SymbolicInferenceEngine,
)
from onnx_ir._shape_type_inference.ops.matmul import MatMulInferrer
from onnx_ir._shape_type_inference.ops.standard_ops import BinaryInferrer

# Create custom engine with specific operations
inferrers = [
    MatMulInferrer(),
    BinaryInferrer("Add"),
    BinaryInferrer("Mul"),
]

engine = SymbolicInferenceEngine(inferrers, ReconciliationPolicy.STRICT)
```

## Opset Version Support

Each inferrer specifies supported ONNX opset versions to handle API changes:

```python
class Squeeze12Inferrer(NodeInferrer):
    def __init__(self):
        super().__init__("Squeeze", opsets=range(1, 13))


class Squeeze13Inferrer(NodeInferrer):
    def __init__(self):
        super().__init__("Squeeze", opsets=range(13, 24))
```

## Error Handling

The engine provides comprehensive error handling:

- **Validation Errors**: Invalid input/output counts, missing shapes
- **Type Mismatches**: Incompatible input types for binary operations
- **Inference Failures**: Operation-specific inference errors
- **Reconciliation Conflicts**: Value mismatches in strict mode

## Factory Functions

Pre-configured engines for common use cases:

- **`create_standard_inference_engine()`**: Full operation coverage (50+ ops)
- **`create_minimal_inference_engine()`**: Essential operations only

## Subgraphs and ONNX Functions

### Design Approach

#### Subgraph Pre-Processing Strategy

The engine uses a **subgraph-first** approach for cleaner separation of concerns:

1. **Pre-Processing Phase**: Before running node inference, detect and recursively process all subgraphs
2. **Bottom-Up Inference**: Subgraphs are fully inferred before their parent nodes
3. **Simplified Node Logic**: Control flow inferrers (If, Loop, Scan) can assume subgraph shapes are already available

```python
class SymbolicInferenceEngine:
    def _infer_node(self, node: ir.Node, model: ir.Model) -> None:
        # First: recursively infer any subgraphs
        for attr in node.attributes.values():
            if isinstance(attr.value, ir.Graph):
                self._infer_subgraph(attr.value, model)

        # Then: run node-specific inference with subgraphs already processed
        inferrer = self._find_inferrer(node, model)
        result = inferrer.infer(node)  # Subgraph shapes already available
```

#### ONNX Function Support

Functions are handled through **automatic expansion** without custom inferrer logic:

1. **Function Context**: Engine maintains intermediate value mappings during function execution
2. **Transparent Expansion**: Function calls are expanded inline and processed like regular subgraphs
3. **No Custom Logic**: Users don't implement function-specific inferrers - the engine handles it automatically

```python
class SymbolicInferenceEngine:
    def _infer_function_call(self, node: ir.Node, function: ir.Function) -> InferenceResult:
        # Create isolated context for function execution
        function_context = self._create_function_context(node.inputs, function)

        # Process function body as a subgraph
        for func_node in function.nodes:
            self._infer_node_in_context(func_node, function_context)

        # Map function outputs back to caller node
        return self._extract_function_outputs(function_context, function.outputs)
```

### Key Benefits

1. **Cleaner Separation**: Subgraph inference is handled by the engine, not individual inferrers
2. **Automatic Function Support**: No need to implement custom logic for each function
3. **Simplified Debugging**: Each phase (subgraphs → nodes) can be debugged independently
4. **Consistent Context**: Function calls maintain proper variable scoping and type consistency

## Extension Points

To add support for new operations:

1. Create a new inferrer class inheriting from `NodeInferrer`
2. Implement the `infer()` method with operation-specific logic
3. Register with the engine or add to factory functions

```python
class CustomOpInferrer(NodeInferrer):
    def __init__(self):
        super().__init__("CustomOp", opsets=range(1, 24), domain="custom_domain")

    def infer(self, node: ir.Node) -> InferenceResult:
        # Custom inference logic
        return InferenceResult(values=[result_value])
```
20 changes: 20 additions & 0 deletions src/onnx_ir/_shape_type_inference/__init__.py
@@ -0,0 +1,20 @@
"""Symbolic shape and type inference for ONNX IR."""

__all__ = [
    "SymbolicInferenceEngine",
    "InferenceError",
    "NodeInferrer",
    "InferenceResult",
    "InferenceStatus",
]


from onnx_ir._shape_type_inference._common import (
    InferenceResult,
    InferenceStatus,
    NodeInferrer,
)
from onnx_ir._shape_type_inference._engine import (
    InferenceError,
    SymbolicInferenceEngine,
)