@@ -22,6 +22,8 @@ import (
2222 "time"
2323 "unsafe"
2424
25+ "container/list"
26+
2527 "sigs.k8s.io/controller-runtime/pkg/log"
2628 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
2729 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -30,8 +32,8 @@ import (
3032func newIndexer (maxCacheSize int ) * indexer {
3133 t := & indexer {
3234 maxCacheSize : maxCacheSize ,
33- table : make (map [BlockHash ]map [ServerID ]* node ),
34- list : newLinkedList (),
35+ table : make (map [BlockHash ]map [ServerID ]* list. Element ),
36+ ll : list . New (),
3537 }
3638 go t .ReportCacheSize (time .Second )
3739 return t
@@ -42,8 +44,14 @@ func newIndexer(maxCacheSize int) *indexer {
4244type indexer struct {
4345 mu sync.RWMutex
4446 maxCacheSize int
45- table map [BlockHash ]map [ServerID ]* node // from any prefix cache to the cache entry to find the server
46- list * linkedList // LRU list to keep track of the order of entries
47+ table map [BlockHash ]map [ServerID ]* list.Element // from any prefix cache to the cache entry to find the server
48+ ll * list.List // LinkedList to keep track of the order of entries
49+ }
50+
51+ // value is the payload stored in each linked-list element: the (server, hash) pair that evict reads back to delete the matching table entry.
52+ type value struct {
53+ server ServerID
54+ hash BlockHash
4755}
4856
4957// Get returns the set of servers that have the given prefix hash cached.
@@ -68,49 +76,52 @@ func (i *indexer) Add(hashes []BlockHash, server ServerID) {
6876 }
6977}
7078
71- func (i * indexer ) check (hash BlockHash , server ServerID ) (* node , bool ) {
79+ func (i * indexer ) check (hash BlockHash , server ServerID ) (* list. Element , bool ) {
7280 servers , ok := i .table [hash ]
7381 if ! ok {
7482 return nil , false
7583 }
76- n , ok := servers [server ]
77- return n , ok
84+ e , ok := servers [server ]
85+ return e , ok
7886}
7987
8088func (i * indexer ) add (hash BlockHash , server ServerID ) {
81- node , exists := i .check (hash , server )
89+ e , exists := i .check (hash , server )
8290 if exists {
83- i .list . moveToTail ( node )
91+ i .ll . MoveToBack ( e )
8492 } else {
8593 i .create (hash , server )
8694 }
8795}
8896
8997func (i * indexer ) create (hash BlockHash , server ServerID ) {
90- n := & node {
91- hash : hash ,
92- server : server ,
93- }
94-
95- for i .list .size >= i .maxCacheSize {
98+ for i .ll .Len () >= i .maxCacheSize {
9699 // Evict the least recently used entry if we've exceeded the max cache size
97100 i .evict ()
98101 }
99102
100103 if _ , ok := i .table [hash ]; ! ok {
101- i .table [hash ] = make (map [ServerID ]* node )
104+ i .table [hash ] = make (map [ServerID ]* list. Element )
102105 }
103- i.table [hash ][server ] = n
104- i .list .add (n )
106+ v := & value {
107+ server : server ,
108+ hash : hash ,
109+ }
110+ e := i .ll .PushBack (v )
111+ i.table [hash ][server ] = e
105112}
106113
107114// evict removes the least recently used entry from the cache
108115func (i * indexer ) evict () {
109- oldestNode := i .list .dummyHead .next
110- i .list .delete (oldestNode )
116+ oldestNode := i .ll .Front ()
117+ if oldestNode == nil {
118+ return
119+ }
120+ i .ll .Remove (oldestNode )
111121
112- hash := oldestNode .hash
113- server := oldestNode .server
122+ v := oldestNode .Value .(* value )
123+ hash := v .hash
124+ server := v .server
114125 // Remove from the hash map
115126 serverMap := i .table [hash ]
116127 delete (serverMap , server )
@@ -129,8 +140,8 @@ func (i *indexer) ReportCacheSize(interval time.Duration) {
129140 defer ticker .Stop ()
130141 for range ticker .C {
131142 i .mu .RLock ()
132- metrics .RecordPrefixCacheSize (int64 (i .list . size ))
133- log .FromContext (context .TODO ()).V (logutil .TRACE ).Info ("LRU" , "# entries" , i .list . size , "estimated size MB" , i .list . size * i .estimateEntrySize ()/ 1000000 )
143+ metrics .RecordPrefixCacheSize (int64 (i .ll . Len () ))
144+ log .FromContext (context .TODO ()).V (logutil .TRACE ).Info ("LRU" , "# entries" , i .ll . Len () , "estimated size MB" , i .ll . Len () * i .estimateEntrySize ()/ 1000000 )
134145 i .mu .RUnlock ()
135146 }
136147}
@@ -146,7 +157,7 @@ func (i *indexer) estimateEntrySize() int {
146157 // The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
147158 // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
148159 // So unsafe.Sizeof(value{}) should return 8 + 2*16 = 40 bytes (the list.Element wrapper's own fields are not part of this Sizeof).
149- size += int (unsafe .Sizeof (node {}))
160+ size += int (unsafe .Sizeof (value {}))
150161 // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
151162 size += 2 * 63
152163
0 commit comments