1
1
package hnsw
2
2
3
3
import (
4
+ "cmp"
4
5
"fmt"
5
6
"math"
6
7
"math/rand"
@@ -14,28 +15,28 @@ import (
14
15
type Embedding = []float32
15
16
16
17
// Embeddable describes a type that can be embedded in a HNSW graph.
17
- type Embeddable interface {
18
+ type Embeddable [ K cmp. Ordered ] interface {
18
19
// ID returns a unique identifier for the object.
19
- ID () string
20
+ ID () K
20
21
// Embedding returns the embedding of the object.
21
22
// float32 is used for compatibility with OpenAI embeddings.
22
23
Embedding () Embedding
23
24
}
24
25
25
26
// layerNode is a node in a layer of the graph.
26
- type layerNode [T Embeddable ] struct {
27
- Point Embeddable
27
+ type layerNode [K cmp. Ordered , V Embeddable [ K ] ] struct {
28
+ Point Embeddable [ K ]
28
29
// neighbors is map of neighbor IDs to neighbor nodes.
29
30
// It is a map and not a slice to allow for efficient deletes, esp.
30
31
// when M is high.
31
- neighbors map [string ]* layerNode [T ]
32
+ neighbors map [K ]* layerNode [K , V ]
32
33
}
33
34
34
35
// addNeighbor adds a o neighbor to the node, replacing the neighbor
35
36
// with the worst distance if the neighbor set is full.
36
- func (n * layerNode [T ]) addNeighbor (newNode * layerNode [T ], m int , dist DistanceFunc ) {
37
+ func (n * layerNode [K , V ]) addNeighbor (newNode * layerNode [K , V ], m int , dist DistanceFunc ) {
37
38
if n .neighbors == nil {
38
- n .neighbors = make (map [string ]* layerNode [T ], m )
39
+ n .neighbors = make (map [K ]* layerNode [K , V ], m )
39
40
}
40
41
41
42
n .neighbors [newNode .Point .ID ()] = newNode
@@ -46,7 +47,7 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
46
47
// Find the neighbor with the worst distance.
47
48
var (
48
49
worstDist = float32 (math .Inf (- 1 ))
49
- worst * layerNode [T ]
50
+ worst * layerNode [K , V ]
50
51
)
51
52
for _ , neighbor := range n .neighbors {
52
53
d := dist (neighbor .Point .Embedding (), n .Point .Embedding ())
@@ -64,39 +65,39 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
64
65
worst .replenish (m )
65
66
}
66
67
67
- type searchCandidate [T Embeddable ] struct {
68
- node * layerNode [T ]
68
+ type searchCandidate [K cmp. Ordered , V Embeddable [ K ] ] struct {
69
+ node * layerNode [K , V ]
69
70
dist float32
70
71
}
71
72
72
- func (s searchCandidate [T ]) Less (o searchCandidate [T ]) bool {
73
+ func (s searchCandidate [K , V ]) Less (o searchCandidate [K , V ]) bool {
73
74
return s .dist < o .dist
74
75
}
75
76
76
77
// search returns the layer node closest to the target node
77
78
// within the same layer.
78
- func (n * layerNode [T ]) search (
79
+ func (n * layerNode [K , V ]) search (
79
80
// k is the number of candidates in the result set.
80
81
k int ,
81
82
efSearch int ,
82
83
target Embedding ,
83
84
distance DistanceFunc ,
84
- ) []searchCandidate [T ] {
85
+ ) []searchCandidate [K , V ] {
85
86
// This is a basic greedy algorithm to find the entry point at the given level
86
87
// that is closest to the target node.
87
- candidates := heap.Heap [searchCandidate [T ]]{}
88
- candidates .Init (make ([]searchCandidate [T ], 0 , efSearch ))
88
+ candidates := heap.Heap [searchCandidate [K , V ]]{}
89
+ candidates .Init (make ([]searchCandidate [K , V ], 0 , efSearch ))
89
90
candidates .Push (
90
- searchCandidate [T ]{
91
+ searchCandidate [K , V ]{
91
92
node : n ,
92
93
dist : distance (n .Point .Embedding (), target ),
93
94
},
94
95
)
95
96
var (
96
- result = heap.Heap [searchCandidate [T ]]{}
97
- visited = make (map [string ]bool )
97
+ result = heap.Heap [searchCandidate [K , V ]]{}
98
+ visited = make (map [K ]bool )
98
99
)
99
- result .Init (make ([]searchCandidate [T ], 0 , k ))
100
+ result .Init (make ([]searchCandidate [K , V ], 0 , k ))
100
101
101
102
// Begin with the entry node in the result set.
102
103
result .Push (candidates .Min ())
@@ -122,13 +123,13 @@ func (n *layerNode[T]) search(
122
123
dist := distance (neighbor .Point .Embedding (), target )
123
124
improved = improved || dist < result .Min ().dist
124
125
if result .Len () < k {
125
- result .Push (searchCandidate [T ]{node : neighbor , dist : dist })
126
+ result .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
126
127
} else if dist < result .Max ().dist {
127
128
result .PopLast ()
128
- result .Push (searchCandidate [T ]{node : neighbor , dist : dist })
129
+ result .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
129
130
}
130
131
131
- candidates .Push (searchCandidate [T ]{node : neighbor , dist : dist })
132
+ candidates .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
132
133
// Always store candidates if we haven't reached the limit.
133
134
if candidates .Len () > efSearch {
134
135
candidates .PopLast ()
@@ -145,7 +146,7 @@ func (n *layerNode[T]) search(
145
146
return result .Slice ()
146
147
}
147
148
148
- func (n * layerNode [T ]) replenish (m int ) {
149
+ func (n * layerNode [K , V ]) replenish (m int ) {
149
150
if len (n .neighbors ) >= m {
150
151
return
151
152
}
@@ -172,7 +173,7 @@ func (n *layerNode[T]) replenish(m int) {
172
173
173
174
// isolates remove the node from the graph by removing all connections
174
175
// to neighbors.
175
- func (n * layerNode [T ]) isolate (m int ) {
176
+ func (n * layerNode [K , V ]) isolate (m int ) {
176
177
for _ , neighbor := range n .neighbors {
177
178
delete (neighbor .neighbors , n .Point .ID ())
178
179
neighbor .replenish (m )
0 commit comments