From 8589e2ebe241cd1231d52d69d6156b04d028a22a Mon Sep 17 00:00:00 2001 From: Booth Li <1260144395@qq.com> Date: Fri, 12 Jun 2026 07:44:19 +0000 Subject: [PATCH] Wrap MojoOperator forward calls with torch.profiler.record_function so each op shows up in profiler traces with its class name, forward signature, and input tensor shape/dtype metadata. Adds helpers _get_forward_signature_repr (cached per class), _format_tensor_meta, and _build_mojo_shape_repr to build the record name. --- mojo_opset/core/operator.py | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mojo_opset/core/operator.py b/mojo_opset/core/operator.py index fc99bc166..d91409811 100644 --- a/mojo_opset/core/operator.py +++ b/mojo_opset/core/operator.py @@ -1,3 +1,4 @@ +import inspect import os from abc import ABC @@ -67,10 +68,50 @@ def __init__(self, **kwargs): ABC.__init__(self) self.tensor_factory_kwargs = get_tensor_factory_kwargs(**kwargs) + def _call_impl(self, *args, **kwargs): + if torch.autograd._profiler_enabled(): + sig = self._get_forward_signature_repr() + shape_repr = self._build_mojo_shape_repr(args, kwargs) + name = f"{type(self).__name__}{sig}[{shape_repr}]" + with torch.profiler.record_function(name): + return super()._call_impl(*args, **kwargs) + return super()._call_impl(*args, **kwargs) + @abstractmethod def forward(self, *args, **kwargs) -> Tuple[Any]: raise NotImplementedError + @classmethod + def _get_forward_signature_repr(cls) -> str: + cached = cls.__dict__.get("_forward_signature_repr_cache") + if cached is not None: + return cached + try: + sig = inspect.signature(cls.forward) + params = [p for p in sig.parameters.values() if p.name != "self"] + sig_no_self = sig.replace(parameters=params) + repr_str = str(sig_no_self) + except (TypeError, ValueError): + repr_str = "(...)" + cls._forward_signature_repr_cache = repr_str + return repr_str + + @staticmethod + def _format_tensor_meta(t: torch.Tensor) -> str: + dtype_str = str(t.dtype).replace("torch.", "") + shape = ",".join(str(s) for s in t.shape) + return f"({shape}){dtype_str}" + + def _build_mojo_shape_repr(self, args, kwargs) -> str: + parts = [] + for a in args: + if isinstance(a, torch.Tensor): + parts.append(self._format_tensor_meta(a)) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + parts.append(f"{k}={self._format_tensor_meta(v)}") + return ",".join(parts) + def forward_diff_with( self, other_op: "MojoOperator",