diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..5d87646eba3 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -4074,14 +4074,28 @@ def test_kernel_uint8(self, make_input): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_functional_float(self, make_input): check_functional(F.gaussian_noise, make_input(dtype=torch.float32)) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_functional_uint8(self, make_input): check_functional(F.gaussian_noise, make_input(dtype=torch.uint8)) @@ -4092,14 +4106,28 @@ def test_functional_uint8(self, make_input): (F.gaussian_noise, torch.Tensor), (F.gaussian_noise_image, tv_tensors.Image), (F.gaussian_noise_video, tv_tensors.Video), + pytest.param( + F._misc._gaussian_noise_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._misc._gaussian_noise_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_transform_float(self, make_input): def adapter(_, input, __): @@ -4117,7 +4145,14 @@ def adapter(_, input, __): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_transform_uint8(self, make_input): def adapter(_, input, __): diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..dfb2a4721c6 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,6 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -20,6 +21,8 @@ is_pure_tensor, ) +CVCUDA_AVAILABLE = _is_cvcuda_available() + # TODO: do we want/need to expose this? class Identity(Transform): @@ -240,6 +243,9 @@ class GaussianNoise(Transform): Default is True. """ + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: super().__init__() self.mean = mean diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -16,7 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + _is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..483330de7b3 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( @@ -231,6 +238,35 @@ def _gaussian_noise_pil( raise ValueError("Gaussian Noise is not implemented for PIL images.") +def _gaussian_noise_image_cvcuda( + image: "cvcuda.Tensor", + mean: float = 0.0, + sigma: float = 0.1, + clip: bool = True, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + batch_size = image.shape[0] + mu_tensor = cvcuda.as_tensor(torch.full((batch_size,), mean, dtype=torch.float32).cuda(), "N") + sigma_tensor = cvcuda.as_tensor(torch.full((batch_size,), sigma, dtype=torch.float32).cuda(), "N") + + # per-channel means each channel gets unique random noise, same behavior as torch.randn_like + # produce a seed with torch RNG, if seed is manually set then this will be deterministic + # note: clip is not supported in CV-CUDA, so we don't need to clamp the values + # by default, clamping is done for floats, and uint8 overflows so is clamped from 0-255 anyways + return cvcuda.gaussiannoise( + image, + mu=mu_tensor, + sigma=sigma_tensor, + per_channel=True, + seed=int(torch.empty((), dtype=torch.int64).random_().item()), + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(gaussian_noise, _import_cvcuda().Tensor)(_gaussian_noise_image_cvcuda) + + def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: """See :func:`~torchvision.transforms.v2.ToDtype` for details.""" if torch.jit.is_scripting():