diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 429b5c348..02c5b3aa6 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -233,7 +233,7 @@ func (r *Runner) Run(ctx context.Context) error { runtime.SetBlockProfileRate(1) } - err = r.parsePluginsConfiguration(ctx) + err = r.parsePluginsConfiguration(ctx, datastore) if err != nil { setupLog.Error(err, "Failed to parse plugins configuration") return err @@ -310,7 +310,7 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory) } -func (r *Runner) parsePluginsConfiguration(ctx context.Context) error { +func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Datastore) error { if *configText == "" && *configFile == "" { return nil // configuring through code, not through file } @@ -329,8 +329,9 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context) error { } r.registerInTreePlugins() - handle := plugins.NewEppHandle(ctx) + handle := plugins.NewEppHandle(ctx, ds.PodList) config, err := loader.LoadConfig(configBytes, handle, logger) + if err != nil { return fmt.Errorf("failed to load the configuration - %w", err) } diff --git a/pkg/epp/plugins/handle.go b/pkg/epp/plugins/handle.go index 8c9153cf1..c074e9076 100644 --- a/pkg/epp/plugins/handle.go +++ b/pkg/epp/plugins/handle.go @@ -19,6 +19,8 @@ package plugins import ( "context" "fmt" + + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) // Handle provides plugins a set of standard data and tools to work with @@ -27,6 +29,9 @@ type Handle interface { Context() context.Context HandlePlugins + + // PodList lists pods matching the given predicate. 
+ PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics } // HandlePlugins defines a set of APIs to work with instantiated plugins @@ -44,10 +49,14 @@ type HandlePlugins interface { GetAllPluginsWithNames() map[string]Plugin } +// PodListFunc is a function type that filters and returns a list of pod metrics +type PodListFunc func(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + // eppHandle is an implementation of the interface plugins.Handle type eppHandle struct { ctx context.Context HandlePlugins + podList PodListFunc } // Context returns a context the plugins can use, if they need one @@ -84,12 +93,18 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin { return h.plugins } -func NewEppHandle(ctx context.Context) Handle { +// PodList lists pods matching the given predicate. +func (h *eppHandle) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics { + return h.podList(predicate) +} + +func NewEppHandle(ctx context.Context, podList PodListFunc) Handle { return &eppHandle{ ctx: ctx, HandlePlugins: &eppHandlePlugins{ plugins: map[string]Plugin{}, }, + podList: podList, } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index bd9e2c96e..8b68132dc 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -149,3 +149,35 @@ func (i *indexer) reportLRUSize(ctx context.Context, interval time.Duration) { i.mu.RUnlock() } } + +// RemovePod removes a pod and its associated entries from the indexer. +func (i *indexer) RemovePod(pod ServerID) { + i.mu.RLock() + lruCache, exists := i.podToLRU[pod] + i.mu.RUnlock() + + if !exists { + return + } + + // Remove all hashes associated with the pod from hashToPods (triggers eviction callbacks). 
+ for _, hash := range lruCache.Keys() { + lruCache.Remove(hash) + } + + i.mu.Lock() + delete(i.podToLRU, pod) + i.mu.Unlock() +} + +// Pods returns the list of all pods currently tracked in the indexer. +func (i *indexer) Pods() []ServerID { + i.mu.RLock() + defer i.mu.RUnlock() + + pods := make([]ServerID, 0, len(i.podToLRU)) + for pod := range i.podToLRU { + pods = append(pods, pod) + } + return pods +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index 6d4fcc5f4..c35af8e27 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -46,3 +46,63 @@ func TestIndexer_AddAndGet(t *testing.T) { servers = i.Get(BlockHash(4)) assert.Empty(t, servers, "Cache should not contain non-existent hash") } + +func TestIndexer_RemovePodAndEviction(t *testing.T) { + const indexerSize = 10 + + i := newIndexer(context.Background(), indexerSize) + + server1 := ServerID{Namespace: "default", Name: "server1"} + server2 := ServerID{Namespace: "default", Name: "server2"} + + // Add indexerSize hashes to both servers + var hashes []BlockHash + for j := 0; j < indexerSize; j++ { + h := BlockHash(j) + hashes = append(hashes, h) + i.Add([]BlockHash{h}, server1) + i.Add([]BlockHash{h}, server2) + } + + // Ensure all entries are added + assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 should have 10 entries") + assert.Equal(t, indexerSize, i.podToLRU[server2].Len(), "server2 should have 10 entries") + + // Ensure each hash in hashToPods maps to both server1 and server2 + for _, h := range hashes { + pods := i.hashToPods[h] + assert.Len(t, pods, 2, "Each hash should be associated with exactly 2 pods") + assert.Contains(t, pods, server1, "hash should be associated with server1") + assert.Contains(t, pods, server2, "hash should be associated with server2") + } + + // Add indexerSize 
hash to server1 → should evict BlockHash(0) + evictedHash := BlockHash(0) + newHash := BlockHash(indexerSize) + i.Add([]BlockHash{newHash}, server1) + + // server1 LRU should still be at max capacity + assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 LRU should maintain max size") + + // BlockHash(0) should no longer have server1 in hashToPods + pods := i.Get(evictedHash) + assert.NotContains(t, pods, server1, "server1 should be evicted from hashToPods for hash 0") + assert.Contains(t, pods, server2, "server2 should still have hash 0") + + // Remove server2 + i.RemovePod(server2) + + // hashToPods for hash 0 should now be empty + pods = i.Get(evictedHash) + assert.NotContains(t, pods, server2, "server2 should be removed from hash 0") + assert.Empty(t, pods, "hash 0 should have no pods after both eviction and removal") + + // All remaining hashes should map only to server1 + for hash, pods := range i.hashToPods { + assert.Len(t, pods, 1, "hash %v should have only 1 pod after server2 removal", hash) + assert.Contains(t, pods, server1, "hash %v should only contain server1", hash) + } + + // Ensure hashToPods contains exactly indexerSize hashes (post-eviction and server2 removal) + assert.Len(t, i.hashToPods, indexerSize, "hashToPods should contain %d hashes after cleanup", indexerSize) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 4e0416720..a5a957db2 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -21,11 +21,13 @@ import ( "encoding/binary" "encoding/json" "fmt" + "time" "github.com/cespare/xxhash/v2" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" @@ -56,6 +58,10 @@ const ( PrefixCachePluginType = "prefix-cache-scorer" ) +const ( + PodActiveCheckInterval = 2 * time.Minute +) + var DefaultConfig = Config{ HashBlockSize: DefaultHashBlockSize, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, @@ -86,6 +92,8 @@ type podSet map[ServerID]struct{} type Indexer interface { Get(hash BlockHash) podSet Add(hashes []BlockHash, server ServerID) + RemovePod(server ServerID) + Pods() []ServerID } // BlockHash is a hash of the block of request body. @@ -140,7 +148,9 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle } } - return New(handle.Context(), parameters).WithName(name), nil + p := New(handle.Context(), parameters).WithName(name) + go p.CleanUpInactivePods(handle.Context(), handle) + return p, nil } // New initializes a new prefix Plugin and returns its pointer. @@ -246,6 +256,33 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map return res } +// CleanUpInactivePods periodically removes pods that are no longer in the active pod list from the indexer. It blocks until ctx is done, so callers run it in a goroutine. +func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle) { + logger := log.FromContext(ctx).V(logutil.VERBOSE) + ticker := time.NewTicker(PodActiveCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + activePodMetrics := handle.PodList(func(_ backendmetrics.PodMetrics) bool { return true }) + activePods := make(map[ServerID]struct{}, len(activePodMetrics)) + for _, pm := range activePodMetrics { + activePods[ServerID(pm.GetPod().NamespacedName)] = struct{}{} + } + + for _, pod := range m.indexer.Pods() { + if _, ok := activePods[pod]; !ok { + m.indexer.RemovePod(pod) + logger.Info("Removed pod not in active set", "pod", pod) + } + } + } + } +} + // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block. 
// hash(0) is the hash of the model name, since different models generally don't share prefix cache. // For block i, hash(i) = hash(block i content, hash(i-1)). diff --git a/test/utils/handle.go b/test/utils/handle.go index 4a29dda87..273539f81 100644 --- a/test/utils/handle.go +++ b/test/utils/handle.go @@ -19,6 +19,7 @@ package utils import ( "context" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) @@ -33,6 +34,10 @@ func (h *testHandle) Context() context.Context { return h.ctx } +func (h *testHandle) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics { + return []backendmetrics.PodMetrics{} +} + type testHandlePlugins struct { plugins map[string]plugins.Plugin }