Skip to content

Commit 1860d73

Browse files
committed
update to main standards
1 parent 982d21d commit 1860d73

File tree

7 files changed

+64
-69
lines changed

7 files changed

+64
-69
lines changed

test/test_transforms_v2.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -1532,14 +1531,14 @@ def test_functional(self, make_input):
15321531
(F.affine_video, tv_tensors.Video),
15331532
(F.affine_keypoints, tv_tensors.KeyPoints),
15341533
pytest.param(
1535-
F._geometry._affine_cvcuda,
1536-
"cvcuda.Tensor",
1534+
F._geometry._affine_image_cvcuda,
1535+
None,
15371536
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
15381537
),
15391538
],
15401539
)
15411540
def test_functional_signature(self, kernel, input_type):
1542-
if input_type == "cvcuda.Tensor":
1541+
if kernel is F._geometry._affine_image_cvcuda:
15431542
input_type = _import_cvcuda().Tensor
15441543
check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
15451544

@@ -1601,8 +1600,8 @@ def test_functional_image_correctness(
16011600
)
16021601

16031602
if make_input is make_image_cvcuda:
1604-
actual = cvcuda_to_pil_compatible_tensor(actual)
1605-
image = cvcuda_to_pil_compatible_tensor(image)
1603+
actual = F.cvcuda_to_tensor(actual)[0].cpu()
1604+
image = F.cvcuda_to_tensor(image)[0].cpu()
16061605

16071606
expected = F.to_image(
16081607
F.affine(
@@ -1652,8 +1651,8 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed, ma
16521651
actual = transform(image)
16531652

16541653
if make_input is make_image_cvcuda:
1655-
actual = cvcuda_to_pil_compatible_tensor(actual)
1656-
image = cvcuda_to_pil_compatible_tensor(image)
1654+
actual = F.cvcuda_to_tensor(actual)[0].cpu()
1655+
image = F.cvcuda_to_tensor(image)[0].cpu()
16571656

16581657
torch.manual_seed(seed)
16591658
expected = F.to_image(transform(F.to_pil_image(image)))

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,8 @@ class RandomAffine(Transform):
686686

687687
_v1_transform_cls = _transforms.RandomAffine
688688

689-
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
689+
if CVCUDA_AVAILABLE:
690+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
690691

691692
def __init__(
692693
self,

torchvision/transforms/v2/functional/_augment.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
from typing import TYPE_CHECKING
32

43
import PIL.Image
54

@@ -9,15 +8,7 @@
98
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
109
from torchvision.utils import _log_api_usage_once
1110

12-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
13-
14-
15-
CVCUDA_AVAILABLE = _is_cvcuda_available()
16-
17-
if TYPE_CHECKING:
18-
import cvcuda # type: ignore[import-not-found]
19-
if CVCUDA_AVAILABLE:
20-
cvcuda = _import_cvcuda() # noqa: F811
11+
from ._utils import _get_kernel, _register_kernel_internal
2112

2213

2314
def erase(

torchvision/transforms/v2/functional/_color.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import TYPE_CHECKING
2-
31
import PIL.Image
42
import torch
53
from torch.nn.functional import conv2d
@@ -11,15 +9,7 @@
119

1210
from ._misc import _num_value_bits, to_dtype_image
1311
from ._type_conversion import pil_to_tensor, to_pil_image
14-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
15-
16-
17-
CVCUDA_AVAILABLE = _is_cvcuda_available()
18-
19-
if TYPE_CHECKING:
20-
import cvcuda # type: ignore[import-not-found]
21-
if CVCUDA_AVAILABLE:
22-
cvcuda = _import_cvcuda() # noqa: F811
12+
from ._utils import _get_kernel, _register_kernel_internal
2313

2414

2515
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from ._utils import (
3131
_FillTypeJIT,
32+
_get_cvcuda_interp,
3233
_get_kernel,
3334
_import_cvcuda,
3435
_is_cvcuda_available,
@@ -1332,31 +1333,7 @@ def affine_video(
13321333
)
13331334

13341335

1335-
if CVCUDA_AVAILABLE:
1336-
_cvcuda_interp = {
1337-
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
1338-
"bilinear": cvcuda.Interp.LINEAR,
1339-
"linear": cvcuda.Interp.LINEAR,
1340-
2: cvcuda.Interp.LINEAR,
1341-
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
1342-
"bicubic": cvcuda.Interp.CUBIC,
1343-
3: cvcuda.Interp.CUBIC,
1344-
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
1345-
"nearest": cvcuda.Interp.NEAREST,
1346-
0: cvcuda.Interp.NEAREST,
1347-
InterpolationMode.BOX: cvcuda.Interp.BOX,
1348-
"box": cvcuda.Interp.BOX,
1349-
4: cvcuda.Interp.BOX,
1350-
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
1351-
"hamming": cvcuda.Interp.HAMMING,
1352-
5: cvcuda.Interp.HAMMING,
1353-
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
1354-
"lanczos": cvcuda.Interp.LANCZOS,
1355-
1: cvcuda.Interp.LANCZOS,
1356-
}
1357-
1358-
1359-
def _affine_cvcuda(
1336+
def _affine_image_cvcuda(
13601337
image: "cvcuda.Tensor",
13611338
angle: Union[int, float],
13621339
translate: list[float],
@@ -1385,9 +1362,7 @@ def _affine_cvcuda(
13851362
translate_f = [float(t) for t in translate]
13861363
matrix = _get_inverse_affine_matrix([cx, cy], angle, translate_f, scale, shear)
13871364

1388-
interp = _cvcuda_interp.get(interpolation)
1389-
if interp is None:
1390-
raise ValueError(f"Invalid interpolation mode: {interpolation}")
1365+
interp = _get_cvcuda_interp(interpolation)
13911366

13921367
xform = np.array([[matrix[0], matrix[1], matrix[2]], [matrix[3], matrix[4], matrix[5]]], dtype=np.float32)
13931368

@@ -1408,7 +1383,7 @@ def _affine_cvcuda(
14081383

14091384

14101385
if CVCUDA_AVAILABLE:
1411-
_register_kernel_internal(affine, _import_cvcuda().Tensor)(_affine_cvcuda)
1386+
_register_kernel_internal(affine, _import_cvcuda().Tensor)(_affine_image_cvcuda)
14121387

14131388

14141389
def rotate(

torchvision/transforms/v2/functional/_misc.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, TYPE_CHECKING
2+
from typing import Optional
33

44
import PIL.Image
55
import torch
@@ -13,14 +13,7 @@
1313

1414
from ._meta import _convert_bounding_box_format
1515

16-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor
17-
18-
CVCUDA_AVAILABLE = _is_cvcuda_available()
19-
20-
if TYPE_CHECKING:
21-
import cvcuda # type: ignore[import-not-found]
22-
if CVCUDA_AVAILABLE:
23-
cvcuda = _import_cvcuda() # noqa: F811
16+
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
2417

2518

2619
def normalize(

torchvision/transforms/v2/functional/_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import functools
22
from collections.abc import Sequence
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
44

55
import torch
66
from torchvision import tv_tensors
7+
from torchvision.transforms.functional import InterpolationMode
8+
9+
if TYPE_CHECKING:
10+
import cvcuda # type: ignore[import-not-found]
711

812
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
913
_FillTypeJIT = Optional[list[float]]
@@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
177181
return isinstance(inpt, cvcuda.Tensor)
178182
except ImportError:
179183
return False
184+
185+
186+
_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}
187+
188+
189+
def _populate_interpolation_mode_to_cvcuda_interp():
190+
cvcuda = _import_cvcuda()
191+
192+
global _interpolation_mode_to_cvcuda_interp
193+
194+
_interpolation_mode_to_cvcuda_interp = {
195+
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
196+
"bilinear": cvcuda.Interp.LINEAR,
197+
"linear": cvcuda.Interp.LINEAR,
198+
2: cvcuda.Interp.LINEAR,
199+
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
200+
"bicubic": cvcuda.Interp.CUBIC,
201+
3: cvcuda.Interp.CUBIC,
202+
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
203+
"nearest": cvcuda.Interp.NEAREST,
204+
0: cvcuda.Interp.NEAREST,
205+
InterpolationMode.BOX: cvcuda.Interp.BOX,
206+
"box": cvcuda.Interp.BOX,
207+
4: cvcuda.Interp.BOX,
208+
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
209+
"hamming": cvcuda.Interp.HAMMING,
210+
5: cvcuda.Interp.HAMMING,
211+
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
212+
"lanczos": cvcuda.Interp.LANCZOS,
213+
1: cvcuda.Interp.LANCZOS,
214+
}
215+
216+
217+
def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
218+
if len(_interpolation_mode_to_cvcuda_interp) == 0:
219+
_populate_interpolation_mode_to_cvcuda_interp()
220+
221+
interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
222+
if interp is None:
223+
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")
224+
225+
return interp

0 commit comments

Comments
 (0)