Skip to content

Commit d6543ed

Browse files
committed
WIP: replace string key with cmp.Ordered generic
1 parent 34992a1 commit d6543ed

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

graph.go

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package hnsw
22

33
import (
4+
"cmp"
45
"fmt"
56
"math"
67
"math/rand"
@@ -14,28 +15,28 @@ import (
1415
type Embedding = []float32
1516

1617
// Embeddable describes a type that can be embedded in a HNSW graph.
17-
type Embeddable interface {
18+
type Embeddable[K cmp.Ordered] interface {
1819
// ID returns a unique identifier for the object.
19-
ID() string
20+
ID() K
2021
// Embedding returns the embedding of the object.
2122
// float32 is used for compatibility with OpenAI embeddings.
2223
Embedding() Embedding
2324
}
2425

2526
// layerNode is a node in a layer of the graph.
26-
type layerNode[T Embeddable] struct {
27-
Point Embeddable
27+
type layerNode[K cmp.Ordered, V Embeddable[K]] struct {
28+
Point Embeddable[K]
2829
// neighbors is map of neighbor IDs to neighbor nodes.
2930
// It is a map and not a slice to allow for efficient deletes, esp.
3031
// when M is high.
31-
neighbors map[string]*layerNode[T]
32+
neighbors map[K]*layerNode[K, V]
3233
}
3334

3435
// addNeighbor adds a o neighbor to the node, replacing the neighbor
3536
// with the worst distance if the neighbor set is full.
36-
func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFunc) {
37+
func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist DistanceFunc) {
3738
if n.neighbors == nil {
38-
n.neighbors = make(map[string]*layerNode[T], m)
39+
n.neighbors = make(map[K]*layerNode[K, V], m)
3940
}
4041

4142
n.neighbors[newNode.Point.ID()] = newNode
@@ -46,7 +47,7 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
4647
// Find the neighbor with the worst distance.
4748
var (
4849
worstDist = float32(math.Inf(-1))
49-
worst *layerNode[T]
50+
worst *layerNode[K, V]
5051
)
5152
for _, neighbor := range n.neighbors {
5253
d := dist(neighbor.Point.Embedding(), n.Point.Embedding())
@@ -64,39 +65,39 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
6465
worst.replenish(m)
6566
}
6667

67-
type searchCandidate[T Embeddable] struct {
68-
node *layerNode[T]
68+
type searchCandidate[K cmp.Ordered, V Embeddable[K]] struct {
69+
node *layerNode[K, V]
6970
dist float32
7071
}
7172

72-
func (s searchCandidate[T]) Less(o searchCandidate[T]) bool {
73+
func (s searchCandidate[K, V]) Less(o searchCandidate[K, V]) bool {
7374
return s.dist < o.dist
7475
}
7576

7677
// search returns the layer node closest to the target node
7778
// within the same layer.
78-
func (n *layerNode[T]) search(
79+
func (n *layerNode[K, V]) search(
7980
// k is the number of candidates in the result set.
8081
k int,
8182
efSearch int,
8283
target Embedding,
8384
distance DistanceFunc,
84-
) []searchCandidate[T] {
85+
) []searchCandidate[K, V] {
8586
// This is a basic greedy algorithm to find the entry point at the given level
8687
// that is closest to the target node.
87-
candidates := heap.Heap[searchCandidate[T]]{}
88-
candidates.Init(make([]searchCandidate[T], 0, efSearch))
88+
candidates := heap.Heap[searchCandidate[K, V]]{}
89+
candidates.Init(make([]searchCandidate[K, V], 0, efSearch))
8990
candidates.Push(
90-
searchCandidate[T]{
91+
searchCandidate[K, V]{
9192
node: n,
9293
dist: distance(n.Point.Embedding(), target),
9394
},
9495
)
9596
var (
96-
result = heap.Heap[searchCandidate[T]]{}
97-
visited = make(map[string]bool)
97+
result = heap.Heap[searchCandidate[K, V]]{}
98+
visited = make(map[K]bool)
9899
)
99-
result.Init(make([]searchCandidate[T], 0, k))
100+
result.Init(make([]searchCandidate[K, V], 0, k))
100101

101102
// Begin with the entry node in the result set.
102103
result.Push(candidates.Min())
@@ -122,13 +123,13 @@ func (n *layerNode[T]) search(
122123
dist := distance(neighbor.Point.Embedding(), target)
123124
improved = improved || dist < result.Min().dist
124125
if result.Len() < k {
125-
result.Push(searchCandidate[T]{node: neighbor, dist: dist})
126+
result.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
126127
} else if dist < result.Max().dist {
127128
result.PopLast()
128-
result.Push(searchCandidate[T]{node: neighbor, dist: dist})
129+
result.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
129130
}
130131

131-
candidates.Push(searchCandidate[T]{node: neighbor, dist: dist})
132+
candidates.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
132133
// Always store candidates if we haven't reached the limit.
133134
if candidates.Len() > efSearch {
134135
candidates.PopLast()
@@ -145,7 +146,7 @@ func (n *layerNode[T]) search(
145146
return result.Slice()
146147
}
147148

148-
func (n *layerNode[T]) replenish(m int) {
149+
func (n *layerNode[K, V]) replenish(m int) {
149150
if len(n.neighbors) >= m {
150151
return
151152
}
@@ -172,7 +173,7 @@ func (n *layerNode[T]) replenish(m int) {
172173

173174
// isolates remove the node from the graph by removing all connections
174175
// to neighbors.
175-
func (n *layerNode[T]) isolate(m int) {
176+
func (n *layerNode[K, V]) isolate(m int) {
176177
for _, neighbor := range n.neighbors {
177178
delete(neighbor.neighbors, n.Point.ID())
178179
neighbor.replenish(m)

0 commit comments

Comments
 (0)