|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | + |
| 4 | +from pytorch_sparse_utils.validation import validate_nd, validate_dim_size, validate_atleast_nd |
| 5 | + |
| 6 | +@pytest.mark.cpu_and_cuda |
| 7 | +class TestValidate: |
| 8 | + def test_validate_nd(self, device): |
| 9 | + tensor = torch.randn(4, 5, 6, device=device) |
| 10 | + validate_nd(tensor, 3) |
| 11 | + with pytest.raises( |
| 12 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 13 | + match="Expected tensor to be 4D" |
| 14 | + ): |
| 15 | + validate_nd(tensor, 4) |
| 16 | + |
| 17 | + def test_validate_at_least_nd(self, device): |
| 18 | + tensor = torch.randn(4, 5, 6, device=device) |
| 19 | + validate_atleast_nd(tensor, 3) |
| 20 | + with pytest.raises( |
| 21 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 22 | + match="Expected tensor to have at least" |
| 23 | + ): |
| 24 | + validate_atleast_nd(tensor, 4) |
| 25 | + |
| 26 | + def test_validate_dim_size(self, device): |
| 27 | + tensor = torch.randn(3, 4, 5, device=device) |
| 28 | + validate_dim_size(tensor, dim=0, expected_size=3) |
| 29 | + with pytest.raises( |
| 30 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 31 | + match=r"Expected tensor to have shape\[0\]=4" |
| 32 | + ): |
| 33 | + validate_dim_size(tensor, dim=0, expected_size=4) |
0 commit comments