diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 1ba47f60a36..2fd9b9fbc7d 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,3 +1,5 @@
+from typing import TypeVar, cast
+
 import torch
 
 from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
@@ -7,12 +9,14 @@
 from ._tv_tensor import TVTensor
 from ._video import Video
 
+TVTensorLike = TypeVar("TVTensorLike", bound=TVTensor, covariant=True)
+
 
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee, *, like, **kwargs):
+def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike:  # type: ignore
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
     If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
@@ -26,10 +30,10 @@ def wrap(wrappee, *, like, **kwargs):
         Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
-        return BoundingBoxes._wrap(
+        return cast(TVTensorLike, BoundingBoxes._wrap(
             wrappee,
             format=kwargs.get("format", like.format),
             canvas_size=kwargs.get("canvas_size", like.canvas_size),
-        )
+        ))
     else:
         return wrappee.as_subclass(type(like))
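For context, here is a minimal usage sketch of `tv_tensors.wrap` as annotated in this diff. It relies only on the public `torchvision.tv_tensors` API shown above; the point is that with the `TypeVar`-bound signature, a type checker can now infer the return type as `BoundingBoxes` rather than an unannotated value:

```python
import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)

# By default, tensor ops on a TVTensor return a plain torch.Tensor;
# wrap() re-attaches the subclass (and, for BoundingBoxes, the
# format/canvas_size metadata) taken from `like`.
shifted = tv_tensors.wrap(boxes + 2, like=boxes)
assert isinstance(shifted, tv_tensors.BoundingBoxes)
assert shifted.format is boxes.format
assert shifted.canvas_size == boxes.canvas_size
```

Note the `covariant=True` on the `TypeVar` is what forces the `# type: ignore` on the `def` line: checkers reject covariant type variables in parameter position, but the variable is still needed so the return type tracks the concrete subclass of `like`.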