Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mojo_opset/core/operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import os

from abc import ABC
Expand Down Expand Up @@ -67,10 +68,50 @@ def __init__(self, **kwargs):
ABC.__init__(self)
self.tensor_factory_kwargs = get_tensor_factory_kwargs(**kwargs)

def _call_impl(self, *args, **kwargs):
if torch.autograd._profiler_enabled():
sig = self._get_forward_signature_repr()
shape_repr = self._build_mojo_shape_repr(args, kwargs)
name = f"{type(self).__name__}{sig}[{shape_repr}]"
with torch.profiler.record_function(name):
return super()._call_impl(*args, **kwargs)
return super()._call_impl(*args, **kwargs)
Comment on lines +71 to +78

@abstractmethod
def forward(self, *args, **kwargs) -> Tuple[Any]:
raise NotImplementedError

@classmethod
def _get_forward_signature_repr(cls) -> str:
cached = cls.__dict__.get("_forward_signature_repr_cache")
if cached is not None:
return cached
try:
sig = inspect.signature(cls.forward)
params = [p for p in sig.parameters.values() if p.name != "self"]
sig_no_self = sig.replace(parameters=params)
repr_str = str(sig_no_self)
except (TypeError, ValueError):
repr_str = "(...)"
cls._forward_signature_repr_cache = repr_str
return repr_str

@staticmethod
def _format_tensor_meta(t: torch.Tensor) -> str:
dtype_str = str(t.dtype).replace("torch.", "")
shape = ",".join(str(s) for s in t.shape)
return f"({shape}){dtype_str}"
Comment on lines +99 to +103

def _build_mojo_shape_repr(self, args, kwargs) -> str:
parts = []
for a in args:
if isinstance(a, torch.Tensor):
parts.append(self._format_tensor_meta(a))
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
parts.append(f"{k}={self._format_tensor_meta(v)}")
return ",".join(parts)
Comment on lines +105 to +113

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of _build_mojo_shape_repr only inspects direct torch.Tensor arguments. However, many operators accept lists or tuples of tensors (e.g., Concat, Stack, or custom multi-input/multi-output layers). In those cases, the tensor shapes and dtypes will be completely omitted from the profiler trace.

We can make this more robust by recursively formatting elements within lists and tuples so that their shapes and dtypes are also captured in the profiler traces.

    def _build_mojo_shape_repr(self, args, kwargs) -> str:
        def _format_item(item) -> Optional[str]:
            if isinstance(item, torch.Tensor):
                return self._format_tensor_meta(item)
            elif isinstance(item, (list, tuple)):
                inner = [_format_item(x) for x in item]
                inner = [x for x in inner if x is not None]
                if inner:
                    bracket_open, bracket_close = ("[", "]") if isinstance(item, list) else ("(", ")")
                    return f"{bracket_open}{','.join(inner)}{bracket_close}"
            return None

        parts = []
        for a in args:
            formatted = _format_item(a)
            if formatted is not None:
                parts.append(formatted)
        for k, v in kwargs.items():
            formatted = _format_item(v)
            if formatted is not None:
                parts.append(f"{k}={formatted}")
        return ",".join(parts)


def forward_diff_with(
self,
other_op: "MojoOperator",
Expand Down
Loading