diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3ce603c3ed2..2c4003f8767 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,6 +21,7 @@
 import torchvision.transforms.v2 as transforms
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -41,7 +42,6 @@
 )
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -1512,6 +1512,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -1527,9 +1530,16 @@ def test_functional(self, make_input):
             (F.affine_mask, tv_tensors.Mask),
             (F.affine_video, tv_tensors.Video),
             (F.affine_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._affine_image_cvcuda,
+                None,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if kernel is F._geometry._affine_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
@@ -1542,6 +1552,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
         ],
     )
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -1559,8 +1572,19 @@ def test_transform(self, make_input, device):
         "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
     )
     @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
-    def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill):
-        image = make_image(dtype=torch.uint8, device="cpu")
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
+    def test_functional_image_correctness(
+        self, angle, translate, scale, shear, center, interpolation, fill, make_input
+    ):
+        image = make_input(dtype=torch.uint8, device="cpu")

         fill = adapt_fill(fill, dtype=torch.uint8)

@@ -1574,6 +1598,11 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
             interpolation=interpolation,
             fill=fill,
         )
+
+        if make_input is make_image_cvcuda:
+            actual = F.cvcuda_to_tensor(actual)[0].cpu()
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = F.to_image(
             F.affine(
                 F.to_pil_image(image),
@@ -1588,7 +1617,11 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
         )

         mae = (actual.float() - expected.float()).abs().mean()
-        assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
+        if make_input is make_image_cvcuda:
+            # CV-CUDA nearest interpolation does not follow the same algorithm as PIL/torch
+            assert mae < (255 if interpolation is transforms.InterpolationMode.NEAREST else 1), f"mae: {mae}"
+        else:
+            assert mae < (2 if interpolation is transforms.InterpolationMode.NEAREST else 8), f"mae: {mae}"

     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     @pytest.mark.parametrize(
@@ -1596,8 +1629,17 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
     )
     @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
     @pytest.mark.parametrize("seed", list(range(5)))
-    def test_transform_image_correctness(self, center, interpolation, fill, seed):
-        image = make_image(dtype=torch.uint8, device="cpu")
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
+    def test_transform_image_correctness(self, center, interpolation, fill, seed, make_input):
+        image = make_input(dtype=torch.uint8, device="cpu")

         fill = adapt_fill(fill, dtype=torch.uint8)

@@ -1608,11 +1650,19 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
         torch.manual_seed(seed)
         actual = transform(image)

+        if make_input is make_image_cvcuda:
+            actual = F.cvcuda_to_tensor(actual)[0].cpu()
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         torch.manual_seed(seed)
         expected = F.to_image(transform(F.to_pil_image(image)))

         mae = (actual.float() - expected.float()).abs().mean()
-        assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
+        if make_input is make_image_cvcuda:
+            # CV-CUDA nearest interpolation does not follow the same algorithm as PIL/torch
+            assert mae < (255 if interpolation is transforms.InterpolationMode.NEAREST else 1), f"mae: {mae}"
+        else:
+            assert mae < (2 if interpolation is transforms.InterpolationMode.NEAREST else 8), f"mae: {mae}"

     def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
         rot = math.radians(angle)
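The skipif-marked CV-CUDA parameter above is repeated in four parametrizations. A hedged sketch of how it could be factored out (the CVCUDA_IMAGE_PARAM name is hypothetical, not part of this diff; make_image_cvcuda and CVCUDA_AVAILABLE are the helpers the tests already use):

import pytest

# Hypothetical module-level constant reusing the exact mark from the tests above.
CVCUDA_IMAGE_PARAM = pytest.param(
    make_image_cvcuda,
    marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
)

# Usage: @pytest.mark.parametrize("make_input", [make_image, CVCUDA_IMAGE_PARAM])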
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 96166e05e9a..b082ab30f18 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -686,6 +686,9 @@ class RandomAffine(Transform):

     _v1_transform_cls = _transforms.RandomAffine

+    if CVCUDA_AVAILABLE:
+        _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
+
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index bb6051b4e61..e803aa49c60 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -16,7 +16,7 @@
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
-from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


 def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
                 tv_tensors.Mask,
                 tv_tensors.BoundingBoxes,
                 tv_tensors.KeyPoints,
+                _is_cvcuda_tensor,
             ),
         )
     }
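query_chw and query_size can mix classes and _is_cvcuda_tensor in one tuple because torchvision's check_type treats plain callables as predicates and classes as isinstance targets. A minimal sketch of that dispatch, assuming a simplified signature:

from typing import Any, Callable, Union

def check_type(obj: Any, types_or_checks: tuple[Union[type, Callable[[Any], bool]], ...]) -> bool:
    # Classes are matched with isinstance; anything else is called as a predicate.
    # This is what lets _is_cvcuda_tensor stand in for a type that may not be
    # importable when the module is loaded.
    for type_or_check in types_or_checks:
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False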
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 0e27218bc89..0b56f9f56d3 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -4,6 +4,7 @@
 from collections.abc import Sequence
 from typing import Any, Optional, TYPE_CHECKING, Union

+import numpy as np
 import PIL.Image
 import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
@@ -28,6 +29,7 @@
 from ._utils import (
     _FillTypeJIT,
+    _get_cvcuda_interp,
     _get_kernel,
     _import_cvcuda,
     _is_cvcuda_available,
@@ -1331,6 +1333,59 @@ def affine_video(
     )


+def _affine_image_cvcuda(
+    image: "cvcuda.Tensor",
+    angle: Union[int, float],
+    translate: list[float],
+    scale: float,
+    shear: list[float],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
+    fill: _FillTypeJIT = None,
+    center: Optional[list[float]] = None,
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    interpolation = _check_interpolation(interpolation)
+    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
+
+    height, width, num_channels = image.shape[1:]
+
+    # Determine the actual center point (cx, cy).
+    # torchvision rotates around the image center by default; CV-CUDA transforms around the upper-left corner (0, 0).
+    # Unlike the tensor kernel, which works in normalized coordinates centered on the image,
+    # CV-CUDA uses absolute pixel coordinates, so the actual center is passed to _get_inverse_affine_matrix.
+    if center is None:
+        cx, cy = width / 2.0, height / 2.0
+    else:
+        cx, cy = float(center[0]), float(center[1])
+
+    translate_f = [float(t) for t in translate]
+    matrix = _get_inverse_affine_matrix([cx, cy], angle, translate_f, scale, shear)
+
+    interp = _get_cvcuda_interp(interpolation)
+
+    xform = np.array([[matrix[0], matrix[1], matrix[2]], [matrix[3], matrix[4], matrix[5]]], dtype=np.float32)
+
+    if fill is None:
+        border_value = np.zeros(num_channels, dtype=np.float32)
+    elif isinstance(fill, (int, float)):
+        border_value = np.full(num_channels, fill, dtype=np.float32)
+    else:
+        border_value = np.array(fill, dtype=np.float32)[:num_channels]
+
+    return cvcuda.warp_affine(
+        image,
+        xform,
+        flags=interp | cvcuda.Interp.WARP_INVERSE_MAP,
+        border_mode=cvcuda.Border.CONSTANT,
+        border_value=border_value,
+    )
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(affine, _import_cvcuda().Tensor)(_affine_image_cvcuda)
+
+
 def rotate(
     inpt: torch.Tensor,
     angle: float,
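Since _affine_image_cvcuda passes cvcuda.Interp.WARP_INVERSE_MAP, the 2x3 xform maps output pixel coordinates back to input coordinates, rotating about the absolute center (cx, cy) rather than torchvision's normalized center. A self-contained illustration of that convention for a pure rotation (no translate, scale, or shear; the helper name is illustrative, not from the diff):

import numpy as np

def inverse_rotation_xform(angle_deg: float, cx: float, cy: float) -> np.ndarray:
    # Inverse of "rotate by angle about (cx, cy)": rotate by -angle about the same point.
    a = np.radians(angle_deg)
    cos_a, sin_a = np.cos(a), np.sin(a)
    return np.array(
        [
            [cos_a, sin_a, cx - cx * cos_a - cy * sin_a],
            [-sin_a, cos_a, cy + cx * sin_a - cy * cos_a],
        ],
        dtype=np.float32,
    )

# The center is a fixed point of the inverse map: it must land on itself.
xform = inverse_rotation_xform(90.0, cx=16.0, cy=16.0)
assert np.allclose(xform @ np.array([16.0, 16.0, 1.0]), [16.0, 16.0])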
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index 6b8f19f12f4..af03ad018d4 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
     return get_dimensions_image(video)


+def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
+    # CV-CUDA tensors are always in NHWC layout,
+    # while get_dimensions reports [C, H, W].
+    return [image.shape[3], image.shape[1], image.shape[2]]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_dimensions, _import_cvcuda().Tensor)(get_dimensions_image_cvcuda)
+
+
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image(inpt)
@@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:

 get_image_num_channels = get_num_channels

+def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
+    # CV-CUDA tensors are always in NHWC layout,
+    # so the channel count is the last dimension.
+    return image.shape[3]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_num_channels, _import_cvcuda().Tensor)(get_num_channels_image_cvcuda)
+
+
 def get_size(inpt: torch.Tensor) -> list[int]:
     if torch.jit.is_scripting():
         return get_size_image(inpt)
@@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


 if CVCUDA_AVAILABLE:
-    _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
+    _register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)


 @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index 11480b30ef9..a1742ba149f 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -1,9 +1,13 @@
 import functools
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 from torchvision import tv_tensors
+from torchvision.transforms.functional import InterpolationMode
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]

 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[list[float]]
@@ -177,3 +181,39 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
         return isinstance(inpt, cvcuda.Tensor)
     except ImportError:
         return False
+
+
+_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}
+
+
+def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
+    # Populate the lookup table lazily so that cvcuda is only imported on first use.
+    if len(_interpolation_mode_to_cvcuda_interp) == 0:
+        cvcuda = _import_cvcuda()
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING
+        _interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS
+        _interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST
+        _interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST
+        _interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR
+        _interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC
+        _interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX
+        _interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING
+        _interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS
+        # Integer codes follow PIL's resampling constants:
+        # 0=NEAREST, 1=LANCZOS, 2=BILINEAR, 3=BICUBIC, 4=BOX, 5=HAMMING.
+        _interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST
+        _interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS
+        _interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR
+        _interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC
+        _interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX
+        _interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING
+
+    interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
+    if interp is None:
+        raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")
+
+    return interp
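End to end, the new kernel is reached through F.affine's dispatch on cvcuda.Tensor. A hedged usage sketch, assuming CV-CUDA is installed (cvcuda.as_tensor wrapping a CUDA torch tensor with an "NHWC" layout string is standard CV-CUDA usage, and F.cvcuda_to_tensor is the converter the tests above rely on):

import cvcuda
import torch
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

# A batched NHWC uint8 image on the GPU, wrapped as a CV-CUDA tensor.
image = torch.randint(0, 256, (1, 32, 32, 3), dtype=torch.uint8, device="cuda")
cv_image = cvcuda.as_tensor(image, "NHWC")

# Dispatches to _affine_image_cvcuda via the kernel registered for cvcuda.Tensor.
out = F.affine(
    cv_image,
    angle=30.0,
    translate=[0.0, 0.0],
    scale=1.0,
    shear=[0.0, 0.0],
    interpolation=transforms.InterpolationMode.BILINEAR,
)

# Back to a torch tensor (batch element 0), as the tests above do.
out_torch = F.cvcuda_to_tensor(out)[0]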