From a347cf6f6b60896fde0364f9d5ef166b78fbe704 Mon Sep 17 00:00:00 2001 From: Rasmus Haugaard Date: Fri, 15 Mar 2024 09:54:19 +0100 Subject: [PATCH 1/5] added logsumexp tests of edge-cases --- test/composite/test_logsumexp.py | 59 +++++++++++++++++++------------- torch_scatter/testing.py | 5 +++ 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py index 49e7a9c4..48316acc 100644 --- a/test/composite/test_logsumexp.py +++ b/test/composite/test_logsumexp.py @@ -1,31 +1,44 @@ +from itertools import product + +import pytest import torch from torch_scatter import scatter_logsumexp +from torch_scatter.testing import float_dtypes, assert_equal + +tests = [ + [0.5, -2.1, 3.2], + [1e33, 0.5], + [-1e33, 0.5], + [-1e33], + [], + [float("nan"), 0.5], + [float("-inf"), 0.5], + [float("inf"), 0.5], +] + + +@pytest.mark.parametrize('src,dtype', product(tests, float_dtypes)) +def test_logsumexp(src, dtype): + src = torch.tensor(src, dtype=dtype) + index = torch.zeros_like(src, dtype=torch.long) + out_scatter = scatter_logsumexp(src, index, dim_size=1) + out_torch = torch.logsumexp(src, dim=0, keepdim=True) + assert_equal(out_scatter, out_torch, equal_nan=True) + + +def test_logsumexp_parallel_jit(): + splits = [len(src) for src in tests] + srcs = torch.tensor(sum(tests, start=[])) + index = torch.repeat_interleave(torch.tensor(splits)) + srcs.requires_grad_() + outputs = scatter_logsumexp(srcs, index) -def test_logsumexp(): - inputs = torch.tensor([ - 0.5, - 0.5, - 0.0, - -2.1, - 3.2, - 7.0, - -1.0, - -100.0, - ]) - inputs.requires_grad_() - index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4]) - splits = [2, 3, 1, 0, 2] - - outputs = scatter_logsumexp(inputs, index) - - for src, out in zip(inputs.split(splits), outputs.unbind()): - if src.numel() > 0: - assert out.tolist() == torch.logsumexp(src, dim=0).tolist() - else: - assert out.item() == 0.0 + for src, out_scatter in zip(srcs.split(splits), outputs.unbind()): + out_torch = torch.logsumexp(src, dim=0) + assert_equal(out_scatter, out_torch, equal_nan=True) outputs.backward(torch.randn_like(outputs)) jit = torch.jit.script(scatter_logsumexp) - assert jit(inputs, index).tolist() == outputs.tolist() + assert_equal(jit(srcs, index), outputs, equal_nan=True) \ No newline at end of file diff --git a/torch_scatter/testing.py b/torch_scatter/testing.py index 2407b8a0..17569932 100644 --- a/torch_scatter/testing.py +++ b/torch_scatter/testing.py @@ -8,6 +8,7 @@ torch.half, torch.bfloat16, torch.float, torch.double, torch.int, torch.long ] +float_dtypes = list(filter(lambda x: x.is_floating_point, dtypes)) grad_dtypes = [torch.float, torch.double] devices = [torch.device('cpu')] @@ -17,3 +18,7 @@ def tensor(x: Any, dtype: torch.dtype, device: torch.device): return None if x is None else torch.tensor(x, device=device).to(dtype) + + +def assert_equal(actual: torch.Tensor, expected: torch.Tensor, equal_nan=False): + torch.testing.assert_close(actual, expected, equal_nan=equal_nan, rtol=0, atol=0) \ No newline at end of file From 2860e8147058c8cd24bb02be8d69fdc8006715dd Mon Sep 17 00:00:00 2001 From: Rasmus Haugaard Date: Fri, 15 Mar 2024 10:16:23 +0100 Subject: [PATCH 2/5] changed logsumexp to pass new edge-case tests --- torch_scatter/composite/logsumexp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index 355d0c0e..fca75690 100644 --- a/torch_scatter/composite/logsumexp.py +++ 
b/torch_scatter/composite/logsumexp.py @@ -8,8 +8,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - eps: float = 1e-12) -> torch.Tensor: + dim_size: Optional[int] = None) -> torch.Tensor: if not torch.is_floating_point(src): raise ValueError('`scatter_logsumexp` can only be computed over ' 'tensors with floating point data types.') @@ -24,18 +23,19 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, size = list(src.size()) size[dim] = dim_size + max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, device=src.device) - scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0] + scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size) + max_value_per_index.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) max_per_src_element = max_value_per_index.gather(dim, index) - recentered_score = src - max_per_src_element - recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf')) - + + src_recentered = src - max_per_src_element if out is not None: out = out.sub_(max_value_per_index).exp_() - sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out, + sum_per_index = scatter_sum(src_recentered.exp_(), index, dim, out, dim_size) - out = sum_per_index.add_(eps).log_().add_(max_value_per_index) - return out.nan_to_num_(neginf=0.0) + return sum_per_index.log_().add_(max_value_per_index) + From 525fb60608af9062c72a3bee09fafd2ae7695a1f Mon Sep 17 00:00:00 2001 From: Rasmus Haugaard Date: Fri, 15 Mar 2024 10:31:55 +0100 Subject: [PATCH 3/5] fixed flake8 complaints --- test/composite/test_logsumexp.py | 2 +- torch_scatter/composite/logsumexp.py | 3 +-- torch_scatter/testing.py | 6 ++++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py index 48316acc..7980df04 100644 --- a/test/composite/test_logsumexp.py +++ b/test/composite/test_logsumexp.py @@ -41,4 +41,4 @@ def test_logsumexp_parallel_jit(): outputs.backward(torch.randn_like(outputs)) jit = torch.jit.script(scatter_logsumexp) - assert_equal(jit(srcs, index), outputs, equal_nan=True) \ No newline at end of file + assert_equal(jit(srcs, index), outputs, equal_nan=True) diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index fca75690..eb4d9c4e 100644 --- a/torch_scatter/composite/logsumexp.py +++ b/torch_scatter/composite/logsumexp.py @@ -29,7 +29,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size) max_value_per_index.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) max_per_src_element = max_value_per_index.gather(dim, index) - + src_recentered = src - max_per_src_element if out is not None: out = out.sub_(max_value_per_index).exp_() @@ -38,4 +38,3 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, dim_size) return sum_per_index.log_().add_(max_value_per_index) - diff --git a/torch_scatter/testing.py b/torch_scatter/testing.py index 17569932..24ad3877 100644 --- a/torch_scatter/testing.py +++ b/torch_scatter/testing.py @@ -20,5 +20,7 @@ def tensor(x: Any, dtype: torch.dtype, device: torch.device): return None if x is None else torch.tensor(x, device=device).to(dtype) -def assert_equal(actual: torch.Tensor, expected: torch.Tensor, equal_nan=False): - torch.testing.assert_close(actual, expected, equal_nan=equal_nan, rtol=0, 
atol=0) \ No newline at end of file +def assert_equal(actual: torch.Tensor, expected: torch.Tensor, + equal_nan=False): + torch.testing.assert_close(actual, expected, equal_nan=equal_nan, rtol=0, + atol=0) From 97a93d4a38c87a405a1bbf1f42c53c150e7dbd19 Mon Sep 17 00:00:00 2001 From: Rasmus Haugaard Date: Fri, 15 Mar 2024 13:28:48 +0100 Subject: [PATCH 4/5] improved numerical stability of inplace version and added inplace tests --- test/composite/test_logsumexp.py | 21 +++++++++++++++------ torch_scatter/composite/logsumexp.py | 27 +++++++++++++++------------ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py index 7980df04..db5bdccf 100644 --- a/test/composite/test_logsumexp.py +++ b/test/composite/test_logsumexp.py @@ -5,15 +5,13 @@ from torch_scatter import scatter_logsumexp from torch_scatter.testing import float_dtypes, assert_equal +edge_values = [0.0, 1.0, -1e33, 1e33, float("nan"), float("-inf"), + float("inf")] + tests = [ [0.5, -2.1, 3.2], - [1e33, 0.5], - [-1e33, 0.5], - [-1e33], [], - [float("nan"), 0.5], - [float("-inf"), 0.5], - [float("inf"), 0.5], + *map(list, product(edge_values, edge_values)), ] @@ -26,6 +24,17 @@ def test_logsumexp(src, dtype): assert_equal(out_scatter, out_torch, equal_nan=True) +@pytest.mark.parametrize('src,out', product(tests, edge_values)) +def test_logsumexp_inplace(src, out): + src = torch.tensor(src) + out = torch.tensor([out]) + out_scatter = out.clone() + index = torch.zeros_like(src, dtype=torch.long) + scatter_logsumexp(src, index, out=out_scatter) + out_torch = torch.logsumexp(torch.cat([out, src]), dim=0, keepdim=True) + assert_equal(out_scatter, out_torch, equal_nan=True) + + def test_logsumexp_parallel_jit(): splits = [len(src) for src in tests] srcs = torch.tensor(sum(tests, start=[])) diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index eb4d9c4e..e61c9b84 100644 --- a/torch_scatter/composite/logsumexp.py +++ b/torch_scatter/composite/logsumexp.py @@ -15,26 +15,29 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, index = broadcast(index, src, dim) - if out is not None: - dim_size = out.size(dim) - else: - if dim_size is None: + if dim_size is None: + if out is not None: + dim_size = out.size(dim) + else: dim_size = int(index.max()) + 1 + elif out is not None: + assert dim_size == out.size(dim) size = list(src.size()) size[dim] = dim_size - max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, - device=src.device) - scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size) + if out is None: + max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, + device=src.device) + else: + max_value_per_index = out.clone() + scatter_max(src, index, dim, max_value_per_index) max_value_per_index.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) max_per_src_element = max_value_per_index.gather(dim, index) - src_recentered = src - max_per_src_element + src_sub_max = src - max_per_src_element if out is not None: - out = out.sub_(max_value_per_index).exp_() - - sum_per_index = scatter_sum(src_recentered.exp_(), index, dim, out, - dim_size) + out.sub_(max_value_per_index).exp_() + sum_per_index = scatter_sum(src_sub_max.exp_(), index, dim, out, dim_size) return sum_per_index.log_().add_(max_value_per_index) From e0c09aecab58d87ffb45d8475bc6e9be6c57815d Mon Sep 17 00:00:00 2001 From: Rasmus Haugaard Date: Fri, 15 Mar 2024 13:44:22 +0100 Subject: [PATCH 5/5] added test 
coverage of logsumexp dim_size check --- test/composite/test_logsumexp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py index db5bdccf..a196ed69 100644 --- a/test/composite/test_logsumexp.py +++ b/test/composite/test_logsumexp.py @@ -35,7 +35,7 @@ def test_logsumexp_inplace(src, out): assert_equal(out_scatter, out_torch, equal_nan=True) -def test_logsumexp_parallel_jit(): +def test_logsumexp_parallel_backward_jit(): splits = [len(src) for src in tests] srcs = torch.tensor(sum(tests, start=[])) index = torch.repeat_interleave(torch.tensor(splits)) @@ -51,3 +51,14 @@ def test_logsumexp_parallel_jit(): jit = torch.jit.script(scatter_logsumexp) assert_equal(jit(srcs, index), outputs, equal_nan=True) + + +def test_logsumexp_inplace_dimsize(): + # if both `out` and `dim_size` are provided, they should match + src = torch.zeros(3) + index = src.to(torch.long) + out = torch.zeros(1) + + scatter_logsumexp(src, index, 0, out, dim_size=1) + with pytest.raises(AssertionError): + scatter_logsumexp(src, index, 0, out, dim_size=2)
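
A usage sketch of the behavior this series establishes (not part of the
patches; it assumes the patched torch_scatter is installed, and the input
values below are illustrative):

    import torch
    from torch_scatter import scatter_logsumexp

    src = torch.tensor([0.5, -2.1, 3.2, float("-inf")])
    index = torch.tensor([0, 0, 0, 1])

    # Each bucket now matches torch.logsumexp over its slice of `src`:
    # bucket 0 -> ~3.2697, bucket 1 (containing only -inf) -> -inf, and the
    # empty bucket 2 -> -inf. Before this series, the `eps` term and the
    # final nan_to_num_(neginf=0.0) turned the last two buckets into 0.0.
    out = scatter_logsumexp(src, index, dim_size=3)
    print(out)  # tensor([3.2697, -inf, -inf])

Dropping `eps` also removes a small positive bias from every bucket:
log(sum(exp(x - max)) + eps) + max slightly overestimates the true
logsumexp, which the exact-equality tests above (rtol=0, atol=0) can trip
on for buckets whose recentered sum is small.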