Skip to content

Commit 444d02f

Browse files
authored
Add similarity calculation for all pairs of vectors in the unit test for the circular function (#175)
1 parent 54b6465 commit 444d02f

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,18 @@ def test_value(self, dtype, vsa):
127127
)
128128
else:
129129
hv = functional.circular(8, 1000000, vsa, generator=generator, dtype=dtype)
130-
sims = functional.cosine_similarity(hv[0], hv)
131-
sims_diff = sims[:-1] - sims[1:]
132-
assert torch.all(
133-
sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
134-
), "second half must get more similar"
135-
136-
assert torch.allclose(
137-
sims_diff.abs(), torch.tensor(0.25, dtype=sims_diff.dtype), atol=0.005
138-
), "similarity decreases linearly"
130+
131+
for i in range(8-1):
132+
sims = functional.cosine_similarity(hv[0], hv)
133+
sims_diff = sims[:-1] - sims[1:]
134+
assert torch.all(
135+
sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
136+
), f"element #{i}: second half must get more similar"
137+
138+
assert torch.allclose(
139+
sims_diff.abs(), torch.tensor(0.25, dtype=sims_diff.dtype), atol=0.005
140+
), f"element #{i}: similarity decreases linearly"
141+
hv = torch.roll(hv,1,0)
139142

140143
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
141144
@pytest.mark.parametrize("dtype", torch_dtypes)

0 commit comments

Comments
 (0)