Skip to content

Commit bdffc21

Browse files
Add normalization function of hypervectors and deprecate hard_quantize (#173)
* Add normalization function of hypervectors and deprecate hard_quantize * [github-action] formatting fixes * Test newer python versions --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent c23e945 commit bdffc21

File tree

14 files changed

+272
-21
lines changed

14 files changed

+272
-21
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
timeout-minutes: 20
1616
strategy:
1717
matrix:
18-
python-version: ['3.8', '3.9', '3.10']
18+
python-version: ['3.10', '3.11', '3.12']
1919
os: [ubuntu-latest, windows-latest, macos-latest]
2020

2121
steps:

docs/torchhd.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Operations
3434
permute
3535
inverse
3636
negative
37+
normalize
3738
cleanup
3839
randsel
3940
multirandsel

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
permute,
5252
inverse,
5353
negative,
54+
normalize,
5455
cleanup,
5556
create_random_permute,
5657
randsel,
@@ -109,6 +110,7 @@
109110
"permute",
110111
"inverse",
111112
"negative",
113+
"normalize",
112114
"cleanup",
113115
"create_random_permute",
114116
"randsel",

torchhd/functional.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import torch
2727
from torch import LongTensor, FloatTensor, Tensor
2828
from collections import deque
29+
import warnings
2930

3031
from torchhd.tensors.base import VSATensor
3132
from torchhd.tensors.bsc import BSCTensor
@@ -50,6 +51,7 @@
5051
"permute",
5152
"inverse",
5253
"negative",
54+
"normalize",
5355
"cleanup",
5456
"create_random_permute",
5557
"hard_quantize",
@@ -673,6 +675,11 @@ def bundle(input: VSATensor, other: VSATensor) -> VSATensor:
673675
674676
\oplus: \mathcal{H} \times \mathcal{H} \to \mathcal{H}
675677
678+
.. note::
679+
680+
This operation does not normalize the resulting hypervectors.
681+
Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
682+
676683
Args:
677684
input (VSATensor): input hypervector
678685
other (VSATensor): other input hypervector
@@ -885,6 +892,12 @@ def hard_quantize(input: Tensor):
885892
tensor([ 1., -1., -1., -1., 1., -1.])
886893
887894
"""
895+
warnings.warn(
896+
"torchhd.hard_quantize is deprecated, consider using torchhd.normalize instead.",
897+
DeprecationWarning,
898+
stacklevel=2,
899+
)
900+
888901
# Make sure that the output tensor has the same dtype and device
889902
# as the input tensor.
890903
positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
@@ -893,6 +906,35 @@ def hard_quantize(input: Tensor):
893906
return torch.where(input > 0, positive, negative)
894907

895908

909+
def normalize(input: VSATensor) -> VSATensor:
910+
"""Normalize the input hypervectors.
911+
912+
Args:
913+
input (Tensor): input tensor
914+
915+
Shapes:
916+
- Input: :math:`(*)`
917+
- Output: :math:`(*)`
918+
919+
Examples::
920+
921+
>>> x = torchhd.random(4, 10, "MAP").multibundle()
922+
>>> x
923+
MAPTensor([ 0., 0., -2., -2., 2., -2., 2., 2., 2., 0.])
924+
>>> torchhd.normalize(x)
925+
MAPTensor([-1., -1., -1., -1., 1., -1., 1., 1., 1., -1.])
926+
927+
>>> x = torchhd.random(4, 10, "HRR").multibundle()
928+
>>> x
929+
HRRTensor([-0.2999, 0.4686, 0.1797, -0.4830, 0.2718, -0.3663, 0.3079, 0.2558, -1.5157, -0.5196])
930+
>>> torchhd.normalize(x)
931+
HRRTensor([-0.1601, 0.2501, 0.0959, -0.2578, 0.1451, -0.1955, 0.1643, 0.1365, -0.8089, -0.2773])
932+
933+
"""
934+
input = ensure_vsa_tensor(input)
935+
return input.normalize()
936+
937+
896938
def dot_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
897939
"""Dot product between the input vector and each vector in others.
898940
@@ -1037,6 +1079,11 @@ def multiset(input: VSATensor) -> VSATensor:
10371079
10381080
\bigoplus_{i=0}^{n-1} V_i
10391081
1082+
.. note::
1083+
1084+
This operation does not normalize the resulting or intermediate hypervectors.
1085+
Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
1086+
10401087
Args:
10411088
input (VSATensor): input hypervector tensor
10421089

torchhd/memory.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,22 @@ def read(self, query: Tensor) -> VSATensor:
121121
122122
"""
123123
# first dims from query, last dim from value
124-
out_shape = (*query.shape[:-1], self.value_dim)
124+
out_shape = tuple(query.shape[:-1]) + (self.value_dim,)
125125

126126
if query.dim() == 1:
127127
query = query.unsqueeze(0)
128128

129-
# make sure to have at least two dimension for index_add_
130-
intermediate_shape = (*query.shape[:-1], self.value_dim)
129+
intermediate_shape = tuple(query.shape[:-1]) + (self.value_dim,)
131130

132131
similarity = query @ self.keys.T
133132
is_active = similarity >= self.threshold
134133

135-
# sparse matrix-vector multiplication
136-
r_indices, v_indices = is_active.nonzero().T
137-
read = query.new_zeros(intermediate_shape)
138-
read.index_add_(0, r_indices, self.values[v_indices])
139-
return read.view(out_shape)
134+
# Sparse matrix-vector multiplication.
135+
to_indices, from_indices = is_active.nonzero().T
136+
137+
read = torch.zeros(intermediate_shape, dtype=query.dtype, device=query.device)
138+
read.index_add_(0, to_indices, self.values[from_indices])
139+
return read.view(out_shape).as_subclass(functional.MAPTensor)
140140

141141
@torch.no_grad()
142142
def write(self, keys: Tensor, values: Tensor) -> None:
@@ -161,7 +161,7 @@ def write(self, keys: Tensor, values: Tensor) -> None:
161161
similarity = keys @ self.keys.T
162162
is_active = similarity >= self.threshold
163163

164-
# sparse outer product and addition
164+
# Sparse outer product and addition.
165165
from_indices, to_indices = is_active.nonzero().T
166166
self.values.index_add_(0, to_indices, values[from_indices])
167167

torchhd/tensors/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
# SOFTWARE.
2323
#
24-
from typing import List, Set, Any
24+
from typing import List, Set
2525
import torch
2626
from torch import Tensor
2727

@@ -131,6 +131,10 @@ def permute(self, shifts: int = 1) -> "VSATensor":
131131
"""Permute the hypervector"""
132132
raise NotImplementedError
133133

134+
def normalize(self) -> "VSATensor":
135+
"""Normalize the hypervector"""
136+
raise NotImplementedError
137+
134138
def dot_similarity(self, others: "VSATensor") -> Tensor:
135139
"""Inner product with other hypervectors"""
136140
raise NotImplementedError

torchhd/tensors/bsbc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,26 @@ def permute(self, shifts: int = 1) -> "BSBCTensor":
335335
"""
336336
return torch.roll(self, shifts=shifts, dims=-1)
337337

338+
def normalize(self) -> "BSBCTensor":
339+
r"""Normalize the hypervector.
340+
341+
Each operation on BSBC hypervectors ensures it remains normalized, so this returns a copy of self.
342+
343+
Shapes:
344+
- Self: :math:`(*)`
345+
- Output: :math:`(*)`
346+
347+
Examples::
348+
349+
>>> x = torchhd.BSBCTensor.random(4, 6, block_size=64).multibundle()
350+
>>> x
351+
BSBCTensor([28, 27, 20, 44, 57, 18])
352+
>>> x.normalize()
353+
BSBCTensor([28, 27, 20, 44, 57, 18])
354+
355+
"""
356+
return self.clone()
357+
338358
def dot_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
339359
"""Inner product with other hypervectors"""
340360
if dtype is None:

torchhd/tensors/bsc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,26 @@ def permute(self, shifts: int = 1) -> "BSCTensor":
426426
"""
427427
return super().roll(shifts=shifts, dims=-1)
428428

429+
def normalize(self) -> "BSCTensor":
430+
r"""Normalize the hypervector.
431+
432+
Each operation on BSC hypervectors ensures it remains normalized, so this returns a copy of self.
433+
434+
Shapes:
435+
- Self: :math:`(*)`
436+
- Output: :math:`(*)`
437+
438+
Examples::
439+
440+
>>> x = torchhd.BSCTensor.random(4, 6).multibundle()
441+
>>> x
442+
BSCTensor([ True, False, False, False, False, False])
443+
>>> x.normalize()
444+
BSCTensor([ True, False, False, False, False, False])
445+
446+
"""
447+
return self.clone()
448+
429449
def dot_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
430450
"""Inner product with other hypervectors."""
431451
device = self.device

torchhd/tensors/fhrr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,29 @@ def permute(self, shifts: int = 1) -> "FHRRTensor":
375375
"""
376376
return torch.roll(self, shifts=shifts, dims=-1)
377377

378+
def normalize(self) -> "FHRRTensor":
379+
r"""Normalize the hypervector.
380+
381+
The normalization preserves the element phase but sets the magnitude to one.
382+
383+
Shapes:
384+
- Self: :math:`(*)`
385+
- Output: :math:`(*)`
386+
387+
Examples::
388+
389+
>>> x = torchhd.FHRRTensor.random(4, 6).multibundle()
390+
>>> x
391+
FHRRTensor([ 1.0878+0.9382j, 2.0057-1.5603j, -2.2828-1.4410j, 1.9643-1.8269j,
392+
-0.9710-0.0120j, -0.7432+0.6956j])
393+
>>> x.normalize()
394+
FHRRTensor([ 0.7572+0.6531j, 0.7893-0.6140j, -0.8456-0.5338j, 0.7322-0.6810j,
395+
-0.9999-0.0124j, -0.7301+0.6833j])
396+
397+
"""
398+
angle = self.angle()
399+
return torch.complex(angle.cos(), angle.sin())
400+
378401
def dot_similarity(self, others: "FHRRTensor") -> Tensor:
379402
"""Inner product with other hypervectors"""
380403
if others.dim() >= 2:

torchhd/tensors/hrr.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
from torch import Tensor
2727
from torch.fft import fft, ifft
28+
import torch.nn.functional as F
2829
import math
2930

3031
from torchhd.tensors.base import VSATensor
@@ -155,7 +156,7 @@ def random(
155156
) -> "HRRTensor":
156157
"""Creates a set of random independent hypervectors.
157158
158-
The resulting hypervectors are sampled at random from a normal with mean 0 and standard deviation 1/dimensions.
159+
The resulting hypervectors are sampled uniformly at random from the (dimensions - 1)-unit sphere.
159160
160161
Args:
161162
num_vectors (int): the number of hypervectors to generate.
@@ -186,8 +187,8 @@ def random(
186187
raise ValueError(f"{name} vectors must be one of dtype {options}.")
187188

188189
size = (num_vectors, dimensions)
189-
result = torch.empty(size, dtype=dtype, device=device)
190-
result.normal_(0, 1.0 / math.sqrt(dimensions), generator=generator)
190+
result = torch.randn(size, dtype=dtype, device=device, generator=generator)
191+
result = F.normalize(result, p=2, dim=-1)
191192

192193
result.requires_grad = requires_grad
193194
return result.as_subclass(cls)
@@ -362,6 +363,27 @@ def permute(self, shifts: int = 1) -> "HRRTensor":
362363
"""
363364
return torch.roll(self, shifts=shifts, dims=-1)
364365

366+
def normalize(self) -> "HRRTensor":
367+
r"""Normalize the hypervector.
368+
369+
The normalization preserves the direction of the hypervector but makes it unit norm.
370+
This means that it is mapped to the closest point on the unit sphere.
371+
372+
Shapes:
373+
- Self: :math:`(*)`
374+
- Output: :math:`(*)`
375+
376+
Examples::
377+
378+
>>> x = torchhd.HRRTensor.random(4, 6).multibundle()
379+
>>> x
380+
HRRTensor([-0.6150, 0.4260, 0.6975, 0.3110, 0.9387, 0.0696])
381+
>>> x.normalize()
382+
HRRTensor([-0.4317, 0.2990, 0.4897, 0.2184, 0.6590, 0.0489])
383+
384+
"""
385+
return F.normalize(self, p=2, dim=-1)
386+
365387
def dot_similarity(self, others: "HRRTensor") -> Tensor:
366388
"""Inner product with other hypervectors"""
367389
if others.dim() >= 2:

torchhd/tensors/map.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#
2424
import torch
2525
from torch import Tensor
26-
import torch.nn.functional as F
2726
from typing import Set
2827

2928
from torchhd.tensors.base import VSATensor
@@ -38,8 +37,6 @@ class MAPTensor(VSATensor):
3837
supported_dtypes: Set[torch.dtype] = {
3938
torch.float32,
4039
torch.float64,
41-
torch.complex64,
42-
torch.complex128,
4340
torch.int8,
4441
torch.int16,
4542
torch.int32,
@@ -318,6 +315,30 @@ def permute(self, shifts: int = 1) -> "MAPTensor":
318315
"""
319316
return torch.roll(self, shifts=shifts, dims=-1)
320317

318+
def normalize(self) -> "MAPTensor":
319+
r"""Normalize the hypervector.
320+
321+
The normalization sets all positive entries to +1 and all other entries to -1.
322+
323+
Shapes:
324+
- Self: :math:`(*)`
325+
- Output: :math:`(*)`
326+
327+
Examples::
328+
329+
>>> x = torchhd.MAPTensor.random(4, 6).multibundle()
330+
>>> x
331+
MAPTensor([-2., -4., 4., 0., 4., -2.])
332+
>>> x.normalize()
333+
MAPTensor([-1., -1., 1., -1., 1., -1.])
334+
335+
"""
336+
# Ensure that the output tensor has the same dtype and device as the self tensor.
337+
positive = torch.tensor(1.0, dtype=self.dtype, device=self.device)
338+
negative = torch.tensor(-1.0, dtype=self.dtype, device=self.device)
339+
340+
return torch.where(self > 0, positive, negative)
341+
321342
def clipping(self, kappa) -> "MAPTensor":
322343
r"""Performs the clipping function that clips the lower and upper values.
323344

0 commit comments

Comments
 (0)