Skip to content

Commit 2b3a5e4

Browse files
committed
gaussian_noise cvcuda backend
1 parent 98d7dfb commit 2b3a5e4

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

test/test_transforms_v2.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4022,14 +4022,28 @@ def test_kernel_uint8(self, make_input):
40224022

40234023
@pytest.mark.parametrize(
40244024
"make_input",
4025-
[make_image_tensor, make_image, make_video],
4025+
[
4026+
make_image_tensor,
4027+
make_image,
4028+
make_video,
4029+
pytest.param(
4030+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4031+
),
4032+
],
40264033
)
40274034
def test_functional_float(self, make_input):
40284035
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
40294036

40304037
@pytest.mark.parametrize(
40314038
"make_input",
4032-
[make_image_tensor, make_image, make_video],
4039+
[
4040+
make_image_tensor,
4041+
make_image,
4042+
make_video,
4043+
pytest.param(
4044+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4045+
),
4046+
],
40334047
)
40344048
def test_functional_uint8(self, make_input):
40354049
check_functional(F.gaussian_noise, make_input(dtype=torch.uint8))
@@ -4040,14 +4054,28 @@ def test_functional_uint8(self, make_input):
40404054
(F.gaussian_noise, torch.Tensor),
40414055
(F.gaussian_noise_image, tv_tensors.Image),
40424056
(F.gaussian_noise_video, tv_tensors.Video),
4057+
pytest.param(
4058+
F._misc._gaussian_noise_cvcuda,
4059+
"cvcuda.Tensor",
4060+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
4061+
),
40434062
],
40444063
)
40454064
def test_functional_signature(self, kernel, input_type):
4065+
if input_type == "cvcuda.Tensor":
4066+
input_type = _import_cvcuda().Tensor
40464067
check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)
40474068

40484069
@pytest.mark.parametrize(
40494070
"make_input",
4050-
[make_image_tensor, make_image, make_video],
4071+
[
4072+
make_image_tensor,
4073+
make_image,
4074+
make_video,
4075+
pytest.param(
4076+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4077+
),
4078+
],
40514079
)
40524080
def test_transform_float(self, make_input):
40534081
def adapter(_, input, __):
@@ -4065,7 +4093,14 @@ def adapter(_, input, __):
40654093

40664094
@pytest.mark.parametrize(
40674095
"make_input",
4068-
[make_image_tensor, make_image, make_video],
4096+
[
4097+
make_image_tensor,
4098+
make_image,
4099+
make_video,
4100+
pytest.param(
4101+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4102+
),
4103+
],
40694104
)
40704105
def test_transform_uint8(self, make_input):
40714106
def adapter(_, input, __):

torchvision/transforms/v2/functional/_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,35 @@ def _gaussian_noise_pil(
238238
raise ValueError("Gaussian Noise is not implemented for PIL images.")
239239

240240

241+
def _gaussian_noise_cvcuda(
242+
image: "cvcuda.Tensor",
243+
mean: float = 0.0,
244+
sigma: float = 0.1,
245+
clip: bool = True,
246+
) -> "cvcuda.Tensor":
247+
cvcuda = _import_cvcuda()
248+
249+
batch_size = image.shape[0]
250+
mu_tensor = cvcuda.as_tensor(torch.full((batch_size,), mean, dtype=torch.float32).cuda(), "N")
251+
sigma_tensor = cvcuda.as_tensor(torch.full((batch_size,), sigma, dtype=torch.float32).cuda(), "N")
252+
253+
# per-channel means each channel gets unique random noise, same behavior as torch.randn_like
254+
# produce a seed with torch RNG, if seed is manually set then this will be deterministic
255+
# note: clip is not supported in CV-CUDA, so we don't need to clamp the values
256+
# by default, clamping is done for floats, and uint8 overflows so is clamped from 0-255 anyways
257+
return cvcuda.gaussiannoise(
258+
image,
259+
mu=mu_tensor,
260+
sigma=sigma_tensor,
261+
per_channel=True,
262+
seed=int(torch.empty((), dtype=torch.int64).random_().item()),
263+
)
264+
265+
266+
if CVCUDA_AVAILABLE:
267+
_register_kernel_internal(gaussian_noise, _import_cvcuda().Tensor)(_gaussian_noise_cvcuda)
268+
269+
241270
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
242271
"""See :func:`~torchvision.transforms.v2.ToDtype` for details."""
243272
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)