Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/qonnx/custom_op/channels_last/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Importing registers CustomOps in qonnx.custom_op.channels_last domain
from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool

custom_op = dict()

custom_op["Conv"] = Conv
custom_op["MaxPool"] = MaxPool
custom_op["BatchNormalization"] = BatchNormalization
# Legacy dictionary for backward compatibility
custom_op = {
"Conv": Conv,
"MaxPool": MaxPool,
"BatchNormalization": BatchNormalization,
}
31 changes: 17 additions & 14 deletions src/qonnx/custom_op/general/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Importing registers CustomOps in qonnx.custom_op.general domain
from qonnx.custom_op.general.bipolar_quant import BipolarQuant
from qonnx.custom_op.general.debugmarker import DebugMarker
from qonnx.custom_op.general.floatquant import FloatQuant
Expand All @@ -35,20 +36,22 @@
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
from qonnx.custom_op.general.quant import Quant
from qonnx.custom_op.general.trunc import Trunc
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul

custom_op = dict()

custom_op["DebugMarker"] = DebugMarker
custom_op["QuantAvgPool2d"] = QuantAvgPool2d
custom_op["MaxPoolNHWC"] = MaxPoolNHWC
custom_op["GenericPartition"] = GenericPartition
custom_op["MultiThreshold"] = MultiThreshold
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
custom_op["Im2Col"] = Im2Col
custom_op["IntQuant"] = IntQuant
custom_op["Quant"] = IntQuant
custom_op["Trunc"] = Trunc
custom_op["BipolarQuant"] = BipolarQuant
custom_op["FloatQuant"] = FloatQuant
# Legacy dictionary for backward compatibility
custom_op = {
"DebugMarker": DebugMarker,
"QuantAvgPool2d": QuantAvgPool2d,
"MaxPoolNHWC": MaxPoolNHWC,
"GenericPartition": GenericPartition,
"MultiThreshold": MultiThreshold,
"XnorPopcountMatMul": XnorPopcountMatMul,
"Im2Col": Im2Col,
"IntQuant": IntQuant,
"Quant": IntQuant, # Alias
"Trunc": Trunc,
"BipolarQuant": BipolarQuant,
"FloatQuant": FloatQuant,
}
131 changes: 116 additions & 15 deletions src/qonnx/custom_op/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,125 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import importlib
import inspect
from threading import RLock
from typing import Dict, Tuple, Type

from qonnx.custom_op.base import CustomOp
from qonnx.util.basic import get_preferred_onnx_opset

# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class
_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {}

_REGISTRY_LOCK = RLock()

# Maps ONNX domain names to Python module paths (used for imports only)
_DOMAIN_ALIASES: Dict[str, str] = {
"onnx.brevitas": "qonnx.custom_op.general",
}


def add_domain_alias(domain: str, module_path: str) -> None:
"""Map a domain name to a different module path.

Args:
domain: The ONNX domain name (e.g., "finn.custom_op.fpgadataflow")
module_path: The Python module path to use instead (e.g., "finn_custom_ops.fpgadataflow")
"""
with _REGISTRY_LOCK:
_DOMAIN_ALIASES[domain] = module_path


def resolve_domain(domain: str) -> str:
"""Resolve a domain to its actual module path, handling aliases.

Args:
domain: The ONNX domain name

Returns:
Resolved module path
"""
return _DOMAIN_ALIASES.get(domain, domain)


def _discover_custom_op(domain: str, op_type: str) -> bool:
"""Discover and register a single custom op.

Args:
domain: The ONNX domain name
op_type: The specific op type to discover

Returns:
True if op was found and registered, False otherwise
"""
module_path = resolve_domain(domain)

def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True):
"Return a QONNX CustomOp instance for the given ONNX node, if it exists."
op_type = node.op_type
domain = node.domain
if brevitas_exception:
# transparently resolve Brevitas domain ops to qonnx ones
domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general")
try:
opset_module = importlib.import_module(domain)
assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain
inst_wrapper = opset_module.custom_op[op_type]
inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version)
return inst
module = importlib.import_module(module_path)
except ModuleNotFoundError:
raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain)
except KeyError:
raise Exception("Op %s not found in custom opset %s" % (op_type, domain))
return False

# Try namespace lookup
op_class = getattr(module, op_type, None)
if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
_OP_REGISTRY[(domain, op_type)] = op_class
return True

# Try legacy dict
custom_op_dict = getattr(module, 'custom_op', None)
if isinstance(custom_op_dict, dict):
op_class = custom_op_dict.get(op_type)
if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
_OP_REGISTRY[(domain, op_type)] = op_class
return True

return False


def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()):
"""Get a custom op instance for an ONNX node.

Args:
node: ONNX node with domain and op_type attributes
onnx_opset_version: ONNX opset version to use

Returns:
CustomOp instance for the node

Raises:
KeyError: If op_type not found in domain
"""
op_type = node.op_type
domain = node.domain
key = (domain, op_type)

with _REGISTRY_LOCK:
if key in _OP_REGISTRY:
return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)

if _discover_custom_op(domain, op_type):
return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)

module_path = resolve_domain(domain)
raise KeyError(
f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). "
f"Ensure it's exported in the module namespace or in the custom_op dict."
)


def hasCustomOp(domain: str, op_type: str) -> bool:
"""Check if a custom op exists.

Args:
domain: The ONNX domain name
op_type: The operation type name

Returns:
True if the op exists, False otherwise
"""
key = (domain, op_type)

with _REGISTRY_LOCK:
if key in _OP_REGISTRY:
return True
return _discover_custom_op(domain, op_type)
Loading