
Commit 651f62e

fix: Add RemovePod to prefix indexer
Signed-off-by: Kfir Toledo <[email protected]>
1 parent a399f6d commit 651f62e

3 files changed: +130, -2 lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go

Lines changed: 20 additions & 0 deletions
@@ -149,3 +149,23 @@ func (i *indexer) ReportLRUSize(interval time.Duration) {
 		i.mu.RUnlock()
 	}
 }
+
+// RemovePod removes a pod and its associated entries from the indexer.
+func (i *indexer) RemovePod(pod ServerID) {
+	i.mu.RLock()
+	lruCache, exists := i.podToLRU[pod]
+	i.mu.RUnlock()
+
+	if !exists {
+		return
+	}
+
+	// Remove all hashes associated with the pod from hashToPods (triggers eviction callbacks).
+	for _, hash := range lruCache.Keys() {
+		lruCache.Remove(hash)
+	}
+
+	i.mu.Lock()
+	delete(i.podToLRU, pod)
+	i.mu.Unlock()
+}
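
The comment "(triggers eviction callbacks)" is the key to why RemovePod only touches the per-pod LRU: the cleanup of hashToPods is assumed to happen in an eviction callback registered when that LRU is created. Below is a minimal, self-contained sketch of that wiring, not the repository's actual implementation; it assumes hashicorp/golang-lru/v2 and uses simplified stand-ins for BlockHash, ServerID, newIndexer, Add, and the indexer fields, so the real locking and field layout may differ.

// Sketch only: illustrates the eviction-callback wiring that RemovePod relies on.
// All types, fields, and helpers here are simplified stand-ins, not the plugin's real code.
package main

import (
	"fmt"
	"sync"

	lru "github.com/hashicorp/golang-lru/v2"
)

type BlockHash uint64

type ServerID struct{ Namespace, Name string }

type indexer struct {
	mu         sync.RWMutex
	hashToPods map[BlockHash]map[ServerID]struct{}
	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}]
	size       int
}

func newIndexer(size int) *indexer {
	return &indexer{
		hashToPods: make(map[BlockHash]map[ServerID]struct{}),
		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
		size:       size,
	}
}

// Add records hashes for a pod. The per-pod LRU is created with an eviction
// callback that also drops the (hash, pod) pair from hashToPods; that is why
// RemovePod can clean everything up just by calling Remove on each key.
func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
	i.mu.Lock()
	cache, ok := i.podToLRU[pod]
	if !ok {
		cache, _ = lru.NewWithEvict[BlockHash, struct{}](i.size, func(h BlockHash, _ struct{}) {
			// Runs synchronously inside cache.Add/Remove, so callers must not hold
			// i.mu at that point (RemovePod releases its read lock first).
			i.mu.Lock()
			defer i.mu.Unlock()
			if pods, exists := i.hashToPods[h]; exists {
				delete(pods, pod)
				if len(pods) == 0 {
					delete(i.hashToPods, h)
				}
			}
		})
		i.podToLRU[pod] = cache
	}
	for _, h := range hashes {
		if i.hashToPods[h] == nil {
			i.hashToPods[h] = make(map[ServerID]struct{})
		}
		i.hashToPods[h][pod] = struct{}{}
	}
	i.mu.Unlock()

	// Insert into the LRU outside the lock so the eviction callback can acquire it.
	for _, h := range hashes {
		cache.Add(h, struct{}{})
	}
}

func main() {
	idx := newIndexer(2)
	pod := ServerID{Namespace: "default", Name: "server1"}
	idx.Add([]BlockHash{1, 2, 3}, pod) // capacity 2: BlockHash(1) gets evicted
	fmt.Println(idx.hashToPods)        // only hashes 2 and 3 remain; the callback removed hash 1
}

Note that in this sketch the callback acquires i.mu itself, which is consistent with RemovePod releasing its read lock before calling lruCache.Remove.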

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go

Lines changed: 60 additions & 0 deletions
@@ -45,3 +45,63 @@ func TestIndexer_AddAndGet(t *testing.T) {
 	servers = i.Get(BlockHash(4))
 	assert.Empty(t, servers, "Cache should not contain non-existent hash")
 }
+
+func TestIndexer_RemovePodAndEviction(t *testing.T) {
+	const indexerSize = 10
+
+	i := newIndexer(indexerSize)
+
+	server1 := ServerID{Namespace: "default", Name: "server1"}
+	server2 := ServerID{Namespace: "default", Name: "server2"}
+
+	// Add indexerSize hashes to both servers
+	var hashes []BlockHash
+	for j := 0; j < indexerSize; j++ {
+		h := BlockHash(j)
+		hashes = append(hashes, h)
+		i.Add([]BlockHash{h}, server1)
+		i.Add([]BlockHash{h}, server2)
+	}
+
+	// Ensure all entries are added
+	assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 should have 10 entries")
+	assert.Equal(t, indexerSize, i.podToLRU[server2].Len(), "server2 should have 10 entries")
+
+	// Ensure each hash in hashToPods maps to both server1 and server2
+	for _, h := range hashes {
+		pods := i.hashToPods[h]
+		assert.Len(t, pods, 2, "Each hash should be associated with exactly 2 pods")
+		assert.Contains(t, pods, server1, "hash should be associated with server1")
+		assert.Contains(t, pods, server2, "hash should be associated with server2")
+	}
+
+	// Add one more hash (BlockHash(indexerSize)) to server1 → should evict BlockHash(0)
+	evictedHash := BlockHash(0)
+	newHash := BlockHash(indexerSize)
+	i.Add([]BlockHash{newHash}, server1)
+
+	// server1 LRU should still be at max capacity
+	assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 LRU should maintain max size")
+
+	// BlockHash(0) should no longer have server1 in hashToPods
+	pods := i.Get(evictedHash)
+	assert.NotContains(t, pods, server1, "server1 should be evicted from hashToPods for hash 0")
+	assert.Contains(t, pods, server2, "server2 should still have hash 0")
+
+	// Remove server2
+	i.RemovePod(server2)
+
+	// hashToPods for hash 0 should now be empty
+	pods = i.Get(evictedHash)
+	assert.NotContains(t, pods, server2, "server2 should be removed from hash 0")
+	assert.Empty(t, pods, "hash 0 should have no pods after both eviction and removal")
+
+	// All remaining hashes should map only to server1
+	for hash, pods := range i.hashToPods {
+		assert.Len(t, pods, 1, "hash %v should have only 1 pod after server2 removal", hash)
+		assert.Contains(t, pods, server1, "hash %v should only contain server1", hash)
+	}
+
+	// Ensure hashToPods contains exactly indexerSize hashes (post-eviction and server2 removal)
+	assert.Len(t, i.hashToPods, indexerSize, "hashToPods should contain %d hashes after cleanup", indexerSize)
+}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 50 additions & 2 deletions
@@ -21,6 +21,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"time"
 
 	"github.com/cespare/xxhash/v2"
 	k8stypes "k8s.io/apimachinery/pkg/types"

@@ -55,6 +56,11 @@ const (
 	PrefixCachePluginType = "prefix-cache-scorer"
 )
 
+const (
+	PodActiveCheckInterval = 1 * time.Minute
+	PodInactivityTimeout   = 5 * time.Minute
+)
+
 var DefaultConfig = Config{
 	HashBlockSize:          DefaultHashBlockSize,
 	MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,

@@ -84,6 +90,7 @@ type podSet map[ServerID]struct{}
 type Indexer interface {
 	Get(hash BlockHash) podSet
 	Add(hashes []BlockHash, server ServerID)
+	RemovePod(server ServerID)
 }
 
 // BlockHash is a hash of the block of request body.

@@ -125,7 +132,7 @@ var _ framework.Scorer = &Plugin{}
 var _ framework.PostCycle = &Plugin{}
 
 // PrefixCachePluginFactory defines the factory function for Prefix plugin.
-func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
 	parameters := Config{
 		HashBlockSize:          DefaultHashBlockSize,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,

@@ -138,7 +145,9 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
 		}
 	}
 
-	return New(parameters).WithName(name), nil
+	p := New(parameters).WithName(name)
+	go p.StartPodActiveWatcher(handle.Context(), handle)
+	return p, nil
 }
 
 // New initializes a new prefix Plugin and returns its pointer.

@@ -239,6 +248,45 @@ func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
 	return res
 }
 
+// StartPodActiveWatcher starts a goroutine that watches for active pods.
+func (m *Plugin) StartPodActiveWatcher(ctx context.Context, handle plugins.Handle) {
+	logger := log.FromContext(ctx).V(logutil.VERBOSE)
+
+	ticker := time.NewTicker(PodActiveCheckInterval)
+	defer ticker.Stop()
+
+	podLastSeen := make(map[ServerID]time.Time)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			now := time.Now()
+			activePods := handle.GetActivePods()
+
+			// Track active pods
+			activeSet := make(map[ServerID]struct{}, len(activePods))
+			for _, np := range activePods {
+				id := ServerID(np)
+				activeSet[id] = struct{}{}
+				podLastSeen[id] = now
+			}
+
+			// Remove stale pods
+			for pod, lastSeen := range podLastSeen {
+				if _, stillActive := activeSet[pod]; !stillActive {
+					if now.Sub(lastSeen) > PodInactivityTimeout {
+						m.indexer.RemovePod(pod)
+						delete(podLastSeen, pod)
+						logger.Info("Removed inactive pod from prefix cache", "pod", pod)
+					}
+				}
+			}
+		}
+	}
+}
+
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash(0) is the hash of the model name, since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
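
The ticker loop in StartPodActiveWatcher encodes the staleness rule: a pod is removed only after it has been absent from GetActivePods for longer than PodInactivityTimeout, i.e. several consecutive one-minute ticks. As a hypothetical illustration (not part of the commit), here is the same bookkeeping factored into a pure function; stalePods and the lowercase podInactivityTimeout are made-up names, and ServerID is a simplified stand-in for the plugin's type.

// Hypothetical sketch: the stale-pod rule from StartPodActiveWatcher, factored into a
// pure function so the timeout behaviour is easy to see in isolation.
package main

import (
	"fmt"
	"time"
)

type ServerID struct{ Namespace, Name string }

const podInactivityTimeout = 5 * time.Minute

// stalePods marks every currently active pod as seen "now", then returns (and forgets)
// the pods that have been absent for longer than the timeout -- the same bookkeeping
// the ticker case performs before calling indexer.RemovePod.
func stalePods(lastSeen map[ServerID]time.Time, active []ServerID, now time.Time, timeout time.Duration) []ServerID {
	activeSet := make(map[ServerID]struct{}, len(active))
	for _, p := range active {
		activeSet[p] = struct{}{}
		lastSeen[p] = now
	}

	var stale []ServerID
	for pod, seen := range lastSeen {
		if _, stillActive := activeSet[pod]; stillActive {
			continue
		}
		if now.Sub(seen) > timeout {
			stale = append(stale, pod)
			delete(lastSeen, pod)
		}
	}
	return stale
}

func main() {
	p1 := ServerID{Namespace: "default", Name: "server1"}
	p2 := ServerID{Namespace: "default", Name: "server2"}

	lastSeen := map[ServerID]time.Time{}
	t0 := time.Now()

	// Tick 1: both pods reported active.
	stalePods(lastSeen, []ServerID{p1, p2}, t0, podInactivityTimeout)

	// 4 minutes later server2 has disappeared, but it is not past the timeout yet.
	fmt.Println(stalePods(lastSeen, []ServerID{p1}, t0.Add(4*time.Minute), podInactivityTimeout)) // []

	// 6 minutes later it has been gone for more than 5 minutes, so it is reported stale.
	fmt.Println(stalePods(lastSeen, []ServerID{p1}, t0.Add(6*time.Minute), podInactivityTimeout)) // [{default server2}]
}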
