Skip to content

Commit e942b4f

Browse files
authored
Refactor API so that keys are generic (#6)
1 parent 34992a1 commit e942b4f

File tree

8 files changed

+299
-298
lines changed

8 files changed

+299
-298
lines changed

README.md

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@ go get github.com/coder/hnsw@main
3232
```
3333

3434
```go
35-
g := hnsw.NewGraph[hnsw.Vector]()
35+
g := hnsw.NewGraph[int]()
3636
g.Add(
37-
hnsw.MakeVector("1", []float32{1, 1, 1}),
38-
hnsw.MakeVector("2", []float32{1, -1, 0.999}),
39-
hnsw.MakeVector("3", []float32{1, 0, -0.5}),
37+
hnsw.MakeNode(1, []float32{1, 1, 1}),
38+
hnsw.MakeNode(2, []float32{1, -1, 0.999}),
39+
hnsw.MakeNode(3, []float32{1, 0, -0.5}),
4040
)
4141

4242
neighbors := g.Search(
4343
[]float32{0.5, 0.5, 0.5},
4444
1,
4545
)
46-
fmt.Printf("best friend: %v\n", neighbors[0].Embedding())
46+
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
4747
// Output: best friend: [1 1 1]
4848
```
4949

@@ -59,13 +59,13 @@ If you're using a single file as the backend, hnsw provides a convenient `SavedG
5959

6060
```go
6161
path := "some.graph"
62-
g1, err := LoadSavedGraph[hnsw.Vector](path)
62+
g1, err := LoadSavedGraph[int](path)
6363
if err != nil {
6464
panic(err)
6565
}
6666
// Insert some vectors
6767
for i := 0; i < 128; i++ {
68-
g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)}))
68+
g1.Add(hnsw.MakeNode(i, []float32{float32(i)}))
6969
}
7070

7171
// Save to disk
@@ -76,7 +76,7 @@ if err != nil {
7676

7777
// Later...
7878
// g2 is a copy of g1
79-
g2, err := LoadSavedGraph[Vector](path)
79+
g2, err := LoadSavedGraph[int](path)
8080
if err != nil {
8181
panic(err)
8282
}
@@ -94,10 +94,10 @@ nearly at disk speed. On my M3 Macbook I get these benchmark results:
9494
goos: darwin
9595
goarch: arm64
9696
pkg: github.com/coder/hnsw
97-
BenchmarkGraph_Import-16 2733 369803 ns/op 228.65 MB/s 352041 B/op 9880 allocs/op
98-
BenchmarkGraph_Export-16 6046 194441 ns/op 1076.65 MB/s 261854 B/op 3760 allocs/op
97+
BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op
98+
BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op
9999
PASS
100-
ok github.com/coder/hnsw 2.530s
100+
ok github.com/coder/hnsw 2.624s
101101
```
102102

103103
when saving/loading a graph of 100 vectors with 256 dimensions.
@@ -130,18 +130,18 @@ $$
130130

131131
where:
132132
* $n$ is the number of vectors in the graph
133-
* $\text{size(id)}$ is the average size of the ID in bytes
133+
* $\text{size(key)}$ is the average size of the key in bytes
134134
* $M$ is the maximum number of neighbors each node can have
135135
* $d$ is the dimensionality of the vectors
136136
* $mem_{graph}$ is the memory used by the graph structure across all layers
137137
* $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer
138138

139139
You can infer that:
140-
* Connectivity ($M$) is very expensive if IDs are large
141-
* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data
142-
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
140+
* Connectivity ($M$) is very expensive if keys are large
141+
* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data
142+
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
143143

144-
In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes:
144+
In the example of a graph with 256 dimensions and $M = 16$, with 8-byte keys, each vector takes:
145145

146146
* $256 \cdot 4 = 1024$ data bytes
147147
* $16 \cdot 8 = 128$ metadata bytes

analyzer.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package hnsw
22

3+
import "cmp"
4+
35
// Analyzer is a struct that holds a graph and provides
46
// methods for analyzing it. It offers no compatibility guarantee
57
// as the methods of measuring the graph's health will change
68
// with the implementation.
7-
type Analyzer[T Embeddable] struct {
8-
Graph *Graph[T]
9+
type Analyzer[K cmp.Ordered] struct {
10+
Graph *Graph[K]
911
}
1012

1113
func (a *Analyzer[T]) Height() int {
@@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int {
1719
func (a *Analyzer[T]) Connectivity() []float64 {
1820
var layerConnectivity []float64
1921
for _, layer := range a.Graph.layers {
20-
if len(layer.Nodes) == 0 {
22+
if len(layer.nodes) == 0 {
2123
continue
2224
}
2325

2426
var sum float64
25-
for _, node := range layer.Nodes {
27+
for _, node := range layer.nodes {
2628
sum += float64(len(node.neighbors))
2729
}
2830

29-
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))
31+
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes)))
3032
}
3133

3234
return layerConnectivity
@@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 {
3638
func (a *Analyzer[T]) Topography() []int {
3739
var topography []int
3840
for _, layer := range a.Graph.layers {
39-
topography = append(topography, len(layer.Nodes))
41+
topography = append(topography, len(layer.nodes))
4042
}
4143
return topography
4244
}

encode.go

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package hnsw
22

33
import (
44
"bufio"
5+
"cmp"
56
"encoding/binary"
67
"fmt"
78
"io"
@@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) {
4344
*v = string(s)
4445
return len(s), err
4546

47+
case *[]float32:
48+
var ln int
49+
_, err := binaryRead(r, &ln)
50+
if err != nil {
51+
return 0, err
52+
}
53+
54+
*v = make([]float32, ln)
55+
return binary.Size(*v), binary.Read(r, byteOrder, *v)
56+
4657
case io.ReaderFrom:
4758
n, err := v.ReadFrom(r)
4859
return int(n), err
@@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) {
7384
}
7485

7586
return n + n2, nil
87+
case []float32:
88+
n, err := binaryWrite(w, len(v))
89+
if err != nil {
90+
return n, err
91+
}
92+
return n + binary.Size(v), binary.Write(w, byteOrder, v)
7693

7794
default:
7895
sz := binary.Size(data)
@@ -113,7 +130,7 @@ const encodingVersion = 1
113130
// Export writes the graph to a writer.
114131
//
115132
// T must implement io.WriterTo.
116-
func (h *Graph[T]) Export(w io.Writer) error {
133+
func (h *Graph[K]) Export(w io.Writer) error {
117134
distFuncName, ok := distanceFuncToName(h.Distance)
118135
if !ok {
119136
return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance)
@@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error {
134151
return fmt.Errorf("encode number of layers: %w", err)
135152
}
136153
for _, layer := range h.layers {
137-
_, err = binaryWrite(w, len(layer.Nodes))
154+
_, err = binaryWrite(w, len(layer.nodes))
138155
if err != nil {
139156
return fmt.Errorf("encode number of nodes: %w", err)
140157
}
141-
for _, node := range layer.Nodes {
142-
_, err = binaryWrite(w, node.Point)
158+
for _, node := range layer.nodes {
159+
_, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors))
143160
if err != nil {
144-
return fmt.Errorf("encode node point: %w", err)
145-
}
146-
147-
if _, err = binaryWrite(w, len(node.neighbors)); err != nil {
148-
return fmt.Errorf("encode number of neighbors: %w", err)
161+
return fmt.Errorf("encode node data: %w", err)
149162
}
150163

151164
for neighbor := range node.neighbors {
152165
_, err = binaryWrite(w, neighbor)
153166
if err != nil {
154-
return fmt.Errorf("encode neighbor %q: %w", neighbor, err)
167+
return fmt.Errorf("encode neighbor %v: %w", neighbor, err)
155168
}
156169
}
157170
}
@@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error {
164177
// T must implement io.ReaderFrom.
165178
// The imported graph does not have to match the exported graph's parameters (except for
166179
// dimensionality). The graph will converge onto the new parameters.
167-
func (h *Graph[T]) Import(r io.Reader) error {
180+
func (h *Graph[K]) Import(r io.Reader) error {
168181
var (
169182
version int
170183
dist string
@@ -195,55 +208,54 @@ func (h *Graph[T]) Import(r io.Reader) error {
195208
return err
196209
}
197210

198-
h.layers = make([]*layer[T], nLayers)
211+
h.layers = make([]*layer[K], nLayers)
199212
for i := 0; i < nLayers; i++ {
200213
var nNodes int
201214
_, err = binaryRead(r, &nNodes)
202215
if err != nil {
203216
return err
204217
}
205218

206-
nodes := make(map[string]*layerNode[T], nNodes)
219+
nodes := make(map[K]*layerNode[K], nNodes)
207220
for j := 0; j < nNodes; j++ {
208-
var point T
209-
_, err = binaryRead(r, &point)
210-
if err != nil {
211-
return fmt.Errorf("decoding node %d: %w", j, err)
212-
}
213-
221+
var key K
222+
var vec Vector
214223
var nNeighbors int
215-
_, err = binaryRead(r, &nNeighbors)
224+
_, err = multiBinaryRead(r, &key, &vec, &nNeighbors)
216225
if err != nil {
217-
return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err)
226+
return fmt.Errorf("decoding node %d: %w", j, err)
218227
}
219228

220-
neighbors := make([]string, nNeighbors)
229+
neighbors := make([]K, nNeighbors)
221230
for k := 0; k < nNeighbors; k++ {
222-
var neighbor string
231+
var neighbor K
223232
_, err = binaryRead(r, &neighbor)
224233
if err != nil {
225234
return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err)
226235
}
227236
neighbors[k] = neighbor
228237
}
229238

230-
node := &layerNode[T]{
231-
Point: point,
232-
neighbors: make(map[string]*layerNode[T]),
239+
node := &layerNode[K]{
240+
Node: Node[K]{
241+
Key: key,
242+
Value: vec,
243+
},
244+
neighbors: make(map[K]*layerNode[K]),
233245
}
234246

235-
nodes[point.ID()] = node
247+
nodes[key] = node
236248
for _, neighbor := range neighbors {
237249
node.neighbors[neighbor] = nil
238250
}
239251
}
240252
// Fill in neighbor pointers
241253
for _, node := range nodes {
242-
for id := range node.neighbors {
243-
node.neighbors[id] = nodes[id]
254+
for key := range node.neighbors {
255+
node.neighbors[key] = nodes[key]
244256
}
245257
}
246-
h.layers[i] = &layer[T]{Nodes: nodes}
258+
h.layers[i] = &layer[K]{nodes: nodes}
247259
}
248260

249261
return nil
@@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error {
253265
// changes to a file upon calls to Save. It is more convenient
254266
// but less powerful than calling Graph.Export and Graph.Import
255267
// directly.
256-
type SavedGraph[T Embeddable] struct {
257-
*Graph[T]
268+
type SavedGraph[K cmp.Ordered] struct {
269+
*Graph[K]
258270
Path string
259271
}
260272

@@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct {
265277
//
266278
// It does not hold open a file descriptor, so SavedGraph can be forgotten
267279
// without ever calling Save.
268-
func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
280+
func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) {
269281
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
270282
if err != nil {
271283
return nil, err
@@ -276,15 +288,15 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
276288
return nil, err
277289
}
278290

279-
g := NewGraph[T]()
291+
g := NewGraph[K]()
280292
if info.Size() > 0 {
281293
err = g.Import(bufio.NewReader(f))
282294
if err != nil {
283295
return nil, fmt.Errorf("import: %w", err)
284296
}
285297
}
286298

287-
return &SavedGraph[T]{Graph: g, Path: path}, nil
299+
return &SavedGraph[K]{Graph: g, Path: path}, nil
288300
}
289301

290302
// Save writes the graph to the file.

0 commit comments

Comments
 (0)