diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index 7566ca07..11b1b68b 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -39,6 +39,7 @@
 import qonnx.util.basic as util
 import qonnx.util.onnx as onnxutil
 from qonnx.core.datatype import DataType
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
 from qonnx.transformation.general import (
     RemoveStaticGraphInputs,
@@ -619,11 +620,11 @@ def get_nodes_by_op_type(self, op_type):

     def get_finn_nodes(self):
-        """Returns a list of nodes where domain == 'qonnx.*'."""
-        return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node))
+        """Returns a list of nodes whose domain is a registered custom op domain."""
+        return list(filter(lambda x: is_custom_op(x.domain), self.graph.node))

     def get_non_finn_nodes(self):
-        """Returns a list of nodes where domain != 'qonnx.*'."""
-        return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))
+        """Returns a list of nodes whose domain is not a registered custom op domain."""
+        return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node))

     def get_node_index(self, node):
         """Returns current index of given node, or None if not found."""
diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py
index a8f4774c..3a686f7e 100644
--- a/src/qonnx/core/onnx_exec.py
+++ b/src/qonnx/core/onnx_exec.py
@@ -35,10 +35,10 @@

 import qonnx.analysis.topology as ta
 import qonnx.core.execute_custom_node as ex_cu_node
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.util.basic import (
     get_preferred_onnx_opset,
     get_sanitize_quant_tensors,
-    is_finn_op,
     qonnx_make_model,
     sanitize_quant_values,
 )
@@ -49,7 +49,7 @@ def execute_node(node, context, graph, return_full_exec_context=False, opset_ver

     Input/output provided via context."""

-    if is_finn_op(node.domain):
+    if is_custom_op(node.domain, node.op_type):
         ex_cu_node.execute_custom_node(node, context, graph, onnx_opset_version=opset_version)
     else:
         # onnxruntime unfortunately does not implement run_node as defined by ONNX,
diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py
index f1d7c39b..77a048e7 100644
--- a/src/qonnx/custom_op/channels_last/__init__.py
+++ b/src/qonnx/custom_op/channels_last/__init__.py
@@ -1,9 +1,11 @@
+# Importing registers CustomOps in qonnx.custom_op.channels_last domain
 from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
 from qonnx.custom_op.channels_last.conv import Conv
 from qonnx.custom_op.channels_last.max_pool import MaxPool

-custom_op = dict()
-
-custom_op["Conv"] = Conv
-custom_op["MaxPool"] = MaxPool
-custom_op["BatchNormalization"] = BatchNormalization
+# Legacy dictionary for backward compatibility
+custom_op = {
+    "Conv": Conv,
+    "MaxPool": MaxPool,
+    "BatchNormalization": BatchNormalization,
+}
\ No newline at end of file
diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py
index 9b14ea8a..2f3896de 100644
--- a/src/qonnx/custom_op/general/__init__.py
+++ b/src/qonnx/custom_op/general/__init__.py
@@ -26,6 +26,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+# Importing registers CustomOps in qonnx.custom_op.general domain
 from qonnx.custom_op.general.bipolar_quant import BipolarQuant
 from qonnx.custom_op.general.debugmarker import DebugMarker
 from qonnx.custom_op.general.floatquant import FloatQuant
@@ -35,20 +36,22 @@
 from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
 from qonnx.custom_op.general.multithreshold import MultiThreshold
+from qonnx.custom_op.general.quant import Quant
 from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
 from qonnx.custom_op.general.trunc import Trunc
 from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul

-custom_op = dict()
-
-custom_op["DebugMarker"] = DebugMarker
-custom_op["QuantAvgPool2d"] = QuantAvgPool2d
-custom_op["MaxPoolNHWC"] = MaxPoolNHWC
-custom_op["GenericPartition"] = GenericPartition
-custom_op["MultiThreshold"] = MultiThreshold
-custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
-custom_op["Im2Col"] = Im2Col
-custom_op["IntQuant"] = IntQuant
-custom_op["Quant"] = IntQuant
-custom_op["Trunc"] = Trunc
-custom_op["BipolarQuant"] = BipolarQuant
-custom_op["FloatQuant"] = FloatQuant
+# Legacy dictionary for backward compatibility
+custom_op = {
+    "DebugMarker": DebugMarker,
+    "QuantAvgPool2d": QuantAvgPool2d,
+    "MaxPoolNHWC": MaxPoolNHWC,
+    "GenericPartition": GenericPartition,
+    "MultiThreshold": MultiThreshold,
+    "XnorPopcountMatMul": XnorPopcountMatMul,
+    "Im2Col": Im2Col,
+    "IntQuant": IntQuant,
+    "Quant": IntQuant,  # Alias
+    "Trunc": Trunc,
+    "BipolarQuant": BipolarQuant,
+    "FloatQuant": FloatQuant,
+}
\ No newline at end of file
diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py
index 3540bb5a..8eb1b378 100644
--- a/src/qonnx/custom_op/registry.py
+++ b/src/qonnx/custom_op/registry.py
@@ -27,24 +27,234 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import importlib
+import inspect
+from threading import RLock
+from typing import Dict, List, Optional, Tuple, Type

+from qonnx.custom_op.base import CustomOp
 from qonnx.util.basic import get_preferred_onnx_opset

+# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class
+_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {}
+
+_REGISTRY_LOCK = RLock()
+
+# Maps ONNX domain names to Python module paths (used for imports only)
+_DOMAIN_ALIASES: Dict[str, str] = {
+    "onnx.brevitas": "qonnx.custom_op.general",
+}
+
+
+def add_domain_alias(domain: str, module_path: str) -> None:
+    """Map a domain name to a different module path.
+
+    Args:
+        domain: The ONNX domain name (e.g., "finn.custom_op.fpgadataflow")
+        module_path: The Python module path to use instead (e.g., "finn_custom_ops.fpgadataflow")
+    """
+    with _REGISTRY_LOCK:
+        _DOMAIN_ALIASES[domain] = module_path
+
+
+def resolve_domain(domain: str) -> str:
+    """Resolve a domain to its actual module path, handling aliases.
+
+    Args:
+        domain: The ONNX domain name
+
+    Returns:
+        Resolved module path
+    """
+    return _DOMAIN_ALIASES.get(domain, domain)
+
+
+def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None:
+    """Register a custom op directly to a domain at runtime.
+
+    The op_type is automatically derived from the class name.
+    Useful for testing and experimentation. For production, define CustomOps
+    in the appropriate module file.
+
+    Args:
+        domain: ONNX domain name (e.g., "qonnx.custom_op.general")
+        op_class: CustomOp subclass
+
+    Example:
+        add_op_to_domain("qonnx.custom_op.general", MyTestOp)
+    """
+    if not inspect.isclass(op_class) or not issubclass(op_class, CustomOp):
+        raise ValueError(f"{op_class} must be a subclass of CustomOp")
+
+    op_type = op_class.__name__
+
+    with _REGISTRY_LOCK:
+        _OP_REGISTRY[(domain, op_type)] = op_class
+
+
+def _discover_custom_op(domain: str, op_type: str) -> bool:
+    """Discover and register a single custom op. Caller must hold _REGISTRY_LOCK.
+
+    Args:
+        domain: The ONNX domain name
+        op_type: The specific op type to discover
+
+    Returns:
+        True if op was found and registered, False otherwise
+    """
+    module_path = resolve_domain(domain)

-def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True):
-    "Return a QONNX CustomOp instance for the given ONNX node, if it exists."
-    op_type = node.op_type
-    domain = node.domain
-    if brevitas_exception:
-        # transparently resolve Brevitas domain ops to qonnx ones
-        domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general")
     try:
-        opset_module = importlib.import_module(domain)
-        assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain
-        inst_wrapper = opset_module.custom_op[op_type]
-        inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version)
-        return inst
+        module = importlib.import_module(module_path)
     except ModuleNotFoundError:
-        raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain)
-    except KeyError:
-        raise Exception("Op %s not found in custom opset %s" % (op_type, domain))
+        return False
+
+    # Try namespace lookup
+    op_class = getattr(module, op_type, None)
+    if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
+        _OP_REGISTRY[(domain, op_type)] = op_class
+        return True
+
+    # Try legacy dict
+    custom_op_dict = getattr(module, "custom_op", None)
+    if isinstance(custom_op_dict, dict):
+        op_class = custom_op_dict.get(op_type)
+        if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
+            _OP_REGISTRY[(domain, op_type)] = op_class
+            return True
+
+    return False
+
+
+def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()):
+    """Get a custom op instance for an ONNX node.
+
+    Args:
+        node: ONNX node with domain and op_type attributes
+        onnx_opset_version: ONNX opset version to use
+
+    Returns:
+        CustomOp instance for the node
+
+    Raises:
+        KeyError: If the domain module cannot be imported or op_type is not found in it
+    """
+    op_type = node.op_type
+    domain = node.domain
+    key = (domain, op_type)
+
+    with _REGISTRY_LOCK:
+        if key in _OP_REGISTRY:
+            return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)
+
+        if _discover_custom_op(domain, op_type):
+            return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)
+
+    module_path = resolve_domain(domain)
+    raise KeyError(
+        f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). "
+        f"Ensure it's exported in the module namespace or in the custom_op dict."
+    )
+
+
+def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool:
+    """Check if a custom op exists or if a domain has any custom ops.
+
+    Args:
+        domain: The ONNX domain name
+        op_type: Optional operation type name. If None, checks if domain has any ops.
+
+    Returns:
+        True if the specific op exists (when op_type given) or
+        if any ops exist for the domain (when op_type=None), False otherwise
+    """
+    # Empty domain means standard ONNX op
+    if not domain:
+        return False
+
+    with _REGISTRY_LOCK:
+        if op_type is not None:
+            # Check for specific op
+            key = (domain, op_type)
+            if key in _OP_REGISTRY:
+                return True
+            return _discover_custom_op(domain, op_type)
+        else:
+            # Check if domain has any registered ops
+            if any(d == domain for d, _ in _OP_REGISTRY.keys()):
+                return True
+            # Fallback: treat any importable domain module as providing custom ops
+            module_path = resolve_domain(domain)
+            try:
+                importlib.import_module(module_path)
+                return True
+            except (ModuleNotFoundError, ValueError):
+                return False
+
+
+def hasCustomOp(domain: str, op_type: str) -> bool:
+    """Deprecated: Use is_custom_op instead.
+
+    Check if a custom op exists.
+
+    Args:
+        domain: The ONNX domain name
+        op_type: The operation type name
+
+    Returns:
+        True if the op exists, False otherwise
+    """
+    import warnings
+    warnings.warn(
+        "hasCustomOp is deprecated and will be removed in QONNX v1.0. "
+        "Use is_custom_op instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    return is_custom_op(domain, op_type)
+
+
+def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]:
+    """Get all CustomOp classes available in a domain.
+
+    Args:
+        domain: ONNX domain name (e.g., "qonnx.custom_op.general")
+
+    Returns:
+        List of (op_type, op_class) tuples
+
+    Example:
+        ops = get_ops_in_domain("qonnx.custom_op.general")
+        for op_name, op_class in ops:
+            print(f"{op_name}: {op_class}")
+    """
+    ops = []
+    module_path = resolve_domain(domain)
+
+    with _REGISTRY_LOCK:
+        # Strategy 1: Get cached ops (fast path)
+        for (d, op_type), op_class in _OP_REGISTRY.items():
+            if d == domain:
+                ops.append((op_type, op_class))
+
+        # Strategy 2: Discover from module (for uncached ops)
+        try:
+            module = importlib.import_module(module_path)
+
+            # Check namespace exports
+            for name, obj in inspect.getmembers(module):
+                if (inspect.isclass(obj) and
+                    issubclass(obj, CustomOp) and
+                    obj is not CustomOp and
+                    not name.startswith("_") and
+                    not any(op[0] == name for op in ops)):
+                    ops.append((name, obj))
+
+            # Check legacy custom_op dict (validate entries like the namespace path does)
+            if hasattr(module, "custom_op") and isinstance(module.custom_op, dict):
+                for name, cls in module.custom_op.items():
+                    if inspect.isclass(cls) and issubclass(cls, CustomOp) and not any(op[0] == name for op in ops):
+                        ops.append((name, cls))
+        except ModuleNotFoundError:
+            pass  # Domain doesn't exist as module, return cached ops only
+
+    return ops
diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py
index 81143e45..2e23d771 100644
--- a/src/qonnx/transformation/infer_data_layouts.py
+++ b/src/qonnx/transformation/infer_data_layouts.py
@@ -30,15 +30,16 @@

 import qonnx.core.data_layout as DataLayout
 import qonnx.custom_op.registry as registry
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.transformation.base import Transformation
-from qonnx.util.basic import get_by_name, is_finn_op
+from qonnx.util.basic import get_by_name


 def _dims_to_layout(model, node, ndims):
     if ndims == 2:
         return DataLayout.NC
     else:
-        if is_finn_op(node.domain):
+        if is_custom_op(node.domain):
             if node.op_type == "MultiThreshold" or node.op_type == "QuantAvgPool2d":
                 mt_inst = registry.getCustomOp(node)
                 layout = mt_inst.get_nodeattr("data_layout")
@@ -72,7 +73,7 @@ def _infer_node_data_layout(model, node):
     Returns True if any changes were made."""
     old_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
     try:
-        if is_finn_op(node.domain):
+        if is_custom_op(node.domain):
             # try to guess based on number of output dims
             for o in node.output:
                 ndims = len(model.get_tensor_shape(o))
diff --git a/src/qonnx/transformation/infer_datatypes.py b/src/qonnx/transformation/infer_datatypes.py
index d54fd34f..167e0c3e 100644
--- a/src/qonnx/transformation/infer_datatypes.py
+++ b/src/qonnx/transformation/infer_datatypes.py
@@ -28,9 +28,10 @@

 import qonnx.custom_op.registry as registry
 from qonnx.core.datatype import DataType, ScaledIntType
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.qcdq_to_qonnx import extract_elem_type
-from qonnx.util.basic import get_by_name, is_finn_op
+from qonnx.util.basic import get_by_name


 def is_scaled_int(x):
@@ -82,7 +83,7 @@ def _infer_node_datatype(model, node, allow_scaledint_dtypes):
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     op_type = node.op_type
-    if is_finn_op(node.domain):
+    if is_custom_op(node.domain):
         # handle DataType inference for CustomOp
         try:
             # lookup op_type in registry of CustomOps
diff --git a/src/qonnx/transformation/infer_shapes.py b/src/qonnx/transformation/infer_shapes.py
index 87fbf0ee..3e532abf 100644
--- a/src/qonnx/transformation/infer_shapes.py
+++ b/src/qonnx/transformation/infer_shapes.py
@@ -30,14 +30,14 @@

 import qonnx.custom_op.registry as registry
 from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.transformation.base import Transformation
-from qonnx.util.basic import is_finn_op


 def _make_shape_compatible_op(node, model):
     """Return a shape-compatible non-QONNX op for a given QONNX op.
     Used for shape inference with custom ops."""
-    assert is_finn_op(node.domain), "Node domain is not set to qonnx.*"
+    assert is_custom_op(node.domain), "Node domain is not a registered custom op domain"
     op_type = node.op_type
     try:
         # lookup op_type in registry of CustomOps
@@ -56,7 +56,7 @@ def _hide_finn_ops(model):
     node_ind = 0
     for node in model.graph.node:
         node_ind += 1
-        if is_finn_op(node.domain):
+        if is_custom_op(node.domain):
             new_node = _make_shape_compatible_op(node, model)
             # keep old node name to help debug shape inference issues
             new_node.name = node.name
diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index 3a3ce2af..4e300dd1 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -63,8 +63,21 @@ def qonnx_make_model(graph_proto, **kwargs):


 def is_finn_op(op_type):
-    "Return whether given op_type string is a QONNX or FINN custom op"
-    return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas")
+    """Deprecated: Use is_custom_op from qonnx.custom_op.registry instead.
+
+    Return whether the given domain string (named op_type for historical reasons) is a QONNX or FINN custom op domain.
+    This function uses hard-coded string matching and will be removed in QONNX v1.0.
+    Use the registry-based is_custom_op for better accuracy and extensibility.
+    """
+    import warnings
+    warnings.warn(
+        "is_finn_op is deprecated and will be removed in QONNX v1.0. "
+        "Use 'from qonnx.custom_op.registry import is_custom_op' instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    from qonnx.custom_op.registry import is_custom_op
+    return is_custom_op(op_type)


 def get_num_default_workers():
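
Usage sketches follow; they are illustrative, not part of the change set. The first shows the new runtime-registration path: a stub CustomOp subclass is registered with add_op_to_domain and then resolved through getCustomOp. The name MyTestOp and the no-op method bodies are hypothetical, and the overridden methods assume the abstract interface of qonnx.custom_op.base.CustomOp.

    import onnx

    from qonnx.custom_op.base import CustomOp
    from qonnx.custom_op.registry import add_op_to_domain, getCustomOp

    class MyTestOp(CustomOp):
        # Stub implementations of CustomOp's abstract methods
        def get_nodeattr_types(self):
            return {}

        def make_shape_compatible_op(self, model):
            return None

        def infer_node_datatype(self, model):
            pass

        def execute_node(self, context, graph):
            pass

        def verify_node(self):
            return []

    # Register at runtime; the op_type is derived from the class name
    add_op_to_domain("qonnx.custom_op.general", MyTestOp)

    node = onnx.helper.make_node("MyTestOp", ["inp"], ["out"], domain="qonnx.custom_op.general")
    op_inst = getCustomOp(node)  # resolved from the registry cache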
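
A second sketch covers add_domain_alias, which keeps the original domain string on the ONNX node while redirecting module discovery to a differently-named Python package; both module names here are illustrative.

    from qonnx.custom_op.registry import add_domain_alias

    # Nodes keep domain="finn.custom_op.fpgadataflow" in the ONNX graph, but
    # discovery imports the aliased Python module instead.
    add_domain_alias("finn.custom_op.fpgadataflow", "finn_custom_ops.fpgadataflow")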
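
Finally, a sketch of the migration away from the deprecated helpers: is_custom_op replaces is_finn_op for both domain-level and exact-op checks, and get_ops_in_domain enumerates whatever the registry can see for a domain. The MultiThreshold node construction assumes that op's usual two inputs.

    import onnx

    from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op

    node = onnx.helper.make_node(
        "MultiThreshold", ["v", "thresholds"], ["out"], domain="qonnx.custom_op.general"
    )

    # Domain-level check (replaces is_finn_op(node.domain)):
    assert is_custom_op(node.domain)
    # Exact check for a single op type, as used in execute_node above:
    assert is_custom_op(node.domain, node.op_type)

    # Enumerate everything visible in a domain (cached and discovered):
    for op_name, op_class in get_ops_in_domain("qonnx.custom_op.general"):
        print(op_name, op_class.__name__)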