
Commit 9a444e6

more refactoring
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent bc9a754 commit 9a444e6

File tree

13 files changed: +144 -93 lines changed


cmd/epp/main.go (+2 -2)

@@ -173,11 +173,11 @@ func run() error {
 		&metrics.MetricsScraper{}: podinfo.NewScraperConfig(*refreshMetricsInterval, scrapeTimeout),
 		&models.ModelsScraper{}:   podinfo.NewScraperConfig(modelsScrapeInterval, scrapeTimeout),
 	}
-	podScraperFactory := podinfo.NewPodInfoFactory(podScrapers)
+	podInfoFactory := podinfo.NewPodInfoFactory(podScrapers)
 	// Setup runner.
 	ctx := ctrl.SetupSignalHandler()

-	datastore := datastore.NewDatastore(ctx, podScraperFactory)
+	datastore := datastore.NewDatastore(ctx, podInfoFactory)

 	scheduler := scheduling.NewScheduler(datastore)
 	serverRunner := &runserver.ExtProcServerRunner{

pkg/epp/backend/metrics/logger.go (+20 -18)

@@ -73,13 +73,17 @@ func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometh
 			logger.V(logutil.DEFAULT).Info("Shutting down metrics logger thread")
 			return
 		case <-ticker.C:
-			podsWithFreshMetrics := datastore.PodList(func(s podinfo.PodInfo) bool {
-				metrics := getMetricsFromPodData(s.GetData())
-				return time.Since(metrics.UpdateTime) <= metricsValidityPeriod
+			podsWithFreshMetrics := datastore.PodList(func(podInfo podinfo.PodInfo) bool {
+				if metrics := getMetricsFromPodInfo(podInfo); metrics != nil {
+					return time.Since(metrics.UpdateTime) <= metricsValidityPeriod
+				}
+				return false
 			})
-			podsWithStaleMetrics := datastore.PodList(func(s podinfo.PodInfo) bool {
-				metrics := getMetricsFromPodData(s.GetData())
-				return time.Since(metrics.UpdateTime) > metricsValidityPeriod
+			podsWithStaleMetrics := datastore.PodList(func(podInfo podinfo.PodInfo) bool {
+				if metrics := getMetricsFromPodInfo(podInfo); metrics != nil {
+					return time.Since(metrics.UpdateTime) > metricsValidityPeriod
+				}
+				return false
 			})
 			s := fmt.Sprintf("Current Pods and metrics gathered. Fresh metrics: %+v, Stale metrics: %+v", podsWithFreshMetrics, podsWithStaleMetrics)
 			logger.V(logutil.VERBOSE).Info(s)

@@ -100,29 +104,27 @@ func refreshPrometheusMetrics(logger logr.Logger, datastore Datastore) {
 	var kvCacheTotal float64
 	var queueTotal int

-	podScrapers := datastore.PodGetAll()
-	logger.V(logutil.TRACE).Info("Refreshing Prometheus Metrics", "ReadyPods", len(podScrapers))
-	if len(podScrapers) == 0 {
+	podsInfo := datastore.PodGetAll()
+	logger.V(logutil.TRACE).Info("Refreshing Prometheus Metrics", "ReadyPods", len(podsInfo))
+	if len(podsInfo) == 0 {
 		return
 	}

-	for _, pod := range podScrapers {
-		metrics := getMetricsFromPodData(pod.GetData())
-		if metrics == nil {
-			continue
+	for _, podInfo := range podsInfo {
+		if metrics := getMetricsFromPodInfo(podInfo); metrics != nil {
+			kvCacheTotal += metrics.KVCacheUsagePercent
+			queueTotal += metrics.WaitingQueueSize
 		}
-		kvCacheTotal += metrics.KVCacheUsagePercent
-		queueTotal += metrics.WaitingQueueSize
 	}

-	podTotalCount := len(podScrapers)
+	podTotalCount := len(podsInfo)
 	metrics.RecordInferencePoolAvgKVCache(pool.Name, kvCacheTotal/float64(podTotalCount))
 	metrics.RecordInferencePoolAvgQueueSize(pool.Name, float64(queueTotal/podTotalCount))
 	metrics.RecordinferencePoolReadyPods(pool.Name, float64(podTotalCount))
 }

-func getMetricsFromPodData(podData map[string]podinfo.ScrapedData) *Metrics {
-	metrics, ok := podData[MetricsDataKey]
+func getMetricsFromPodInfo(podInfo podinfo.PodInfo) *Metrics {
+	metrics, ok := podInfo.GetData()[MetricsDataKey]
 	if !ok {
 		return nil // no entry in the map with metrics key
 	}
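
The second hunk cuts off inside getMetricsFromPodInfo. A minimal sketch of how the rest of the function plausibly reads; the type assertion mirrors the cast the old filter code performed inline, but the actual tail is not shown in this diff:

func getMetricsFromPodInfo(podInfo podinfo.PodInfo) *Metrics {
	metrics, ok := podInfo.GetData()[MetricsDataKey]
	if !ok {
		return nil // no entry in the map with metrics key
	}
	// Assumed continuation: assert the stored ScrapedData back to the
	// concrete *Metrics type, returning nil instead of panicking on an
	// unexpected value.
	m, ok := metrics.(*Metrics)
	if !ok {
		return nil
	}
	return m
}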

pkg/epp/backend/pod-info/fake_pod_info.go (+1 -1)

@@ -25,7 +25,7 @@ import (

 var _ PodInfo = &FakePodInfo{}

-// FakePodInfo is an implementation of PodScraper that doesn't run the async scrape loop.
+// FakePodInfo is an implementation of PodInfo that doesn't run the async scrape loop.
 type FakePodInfo struct {
 	Pod  *backend.Pod
 	Data map[string]ScrapedData

pkg/epp/backend/pod-info/types.go (+2 -2)

@@ -42,13 +42,13 @@ type ScraperConfig struct {
 type Scraper interface {
 	// Name returns the name of the scraper.
 	Name() string
-	// Init returns a empty ScrapeResult that will be stored upon initialization of the PodScraper.
+	// InitData returns an empty ScrapedData that will be stored upon initialization of the Scraper.
 	// Each Scraper has its own data.
 	InitData() ScrapedData
 	// Scrape scrapes information from a pod.
 	Scrape(ctx context.Context, pod *backend.Pod, port int) (ScrapedData, error)
 	// ProcessResult processes the object returned from the Scrape function.
-	// This function should update PodScraper data field with the new result.
+	// This function should update the PodInfo data field with the new result.
 	ProcessResult(ScrapedData)
 }
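
For orientation, a minimal sketch of what implementing this interface could look like. Everything here is illustrative: the scraper name, the stored value, and the assumption that ScrapedData is a marker interface any value satisfies are not from this commit.

// noopScraper is a hypothetical, do-nothing Scraper that only shows the contract.
type noopScraper struct {
	data ScrapedData
}

func (s *noopScraper) Name() string { return "noop" }

// InitData supplies the placeholder stored before the first scrape completes.
func (s *noopScraper) InitData() ScrapedData { return struct{}{} }

// Scrape would normally query the pod on the given port; this sketch just
// echoes the previously processed data.
func (s *noopScraper) Scrape(ctx context.Context, pod *backend.Pod, port int) (ScrapedData, error) {
	return s.data, nil
}

// ProcessResult stores the scrape result on the scraper's data field.
func (s *noopScraper) ProcessResult(d ScrapedData) { s.data = d }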

pkg/epp/controller/inferencemodel_reconciler_test.go (+2 -2)

@@ -195,10 +195,10 @@ func TestInferenceModelReconciler(t *testing.T) {
 				WithObjects(initObjs...).
 				WithIndex(&v1alpha2.InferenceModel{}, datastore.ModelNameIndexKey, indexInferenceModelsByModelName).
 				Build()
-			scraperFactory := podinfo.NewPodInfoFactory(map[podinfo.Scraper]*podinfo.ScraperConfig{
+			podInfoFactory := podinfo.NewPodInfoFactory(map[podinfo.Scraper]*podinfo.ScraperConfig{
 				&backendmetrics.FakeMetricsScraper{}: podinfo.NewScraperConfig(time.Second, 5*time.Second),
 			})
-			ds := datastore.NewDatastore(t.Context(), scraperFactory)
+			ds := datastore.NewDatastore(t.Context(), podInfoFactory)
 			for _, m := range test.modelsInStore {
 				ds.ModelSetIfOlder(m)
 			}

pkg/epp/controller/inferencepool_reconciler_test.go (+4 -4)

@@ -32,7 +32,7 @@ import (
 	"sigs.k8s.io/controller-runtime/pkg/client/fake"
 	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/scrapers"
+	podinfo "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/pod-info"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 )

@@ -94,10 +94,10 @@ func TestInferencePoolReconciler(t *testing.T) {
 	namespacedName := types.NamespacedName{Name: pool1.Name, Namespace: pool1.Namespace}
 	req := ctrl.Request{NamespacedName: namespacedName}
 	ctx := context.Background()
-	scraperFactory := scrapers.NewPodInfoFactory(map[scrapers.Scraper]*scrapers.ScraperConfig{
-		&backendmetrics.FakeMetricsScraper{}: scrapers.NewScraperConfig(time.Second, 5*time.Second),
+	podInfoFactory := podinfo.NewPodInfoFactory(map[podinfo.Scraper]*podinfo.ScraperConfig{
+		&backendmetrics.FakeMetricsScraper{}: podinfo.NewScraperConfig(time.Second, 5*time.Second),
 	})
-	datastore := datastore.NewDatastore(ctx, scraperFactory)
+	datastore := datastore.NewDatastore(ctx, podInfoFactory)
 	inferencePoolReconciler := &InferencePoolReconciler{Client: fakeClient, Datastore: datastore}

 	// Step 1: Inception, only ready pods matching pool1 are added to the store.

pkg/epp/datastore/datastore.go (+15 -15)

@@ -102,7 +102,7 @@ func (ds *datastore) Clear() {
 	defer ds.poolAndModelsMu.Unlock()
 	ds.pool = nil
 	ds.models = make(map[string]*v1alpha2.InferenceModel)
-	// stop all pods go routines for data collection before clearing the pods map.
+	// stop all pod goroutines before clearing the pods map.
 	ds.pods.Range(func(_, v any) bool {
 		v.(podinfo.PodInfo).Stop()
 		return true

@@ -251,9 +251,9 @@ func (ds *datastore) PodGetAll() []podinfo.PodInfo {
 func (ds *datastore) PodList(predicate func(podinfo.PodInfo) bool) []podinfo.PodInfo {
 	res := []podinfo.PodInfo{}
 	ds.pods.Range(func(k, v any) bool {
-		podScraper := v.(podinfo.PodInfo)
-		if predicate(podScraper) {
-			res = append(res, podScraper)
+		podInfo := v.(podinfo.PodInfo)
+		if predicate(podInfo) {
+			res = append(res, podInfo)
 		}
 		return true
 	})

@@ -266,24 +266,24 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool {
 		Name:      pod.Name,
 		Namespace: pod.Namespace,
 	}
-	var podScraper podinfo.PodInfo
+	var podInfo podinfo.PodInfo
 	existing, ok := ds.pods.Load(namespacedName)
 	if !ok { // new pod. add a new pod scraper (it's started internally)
-		podScraper := ds.podInfoFactory.NewPodInfo(ds.parentCtx, pod, ds)
-		ds.pods.Store(namespacedName, podScraper)
+		podInfo = ds.podInfoFactory.NewPodInfo(ds.parentCtx, pod, ds)
+		ds.pods.Store(namespacedName, podInfo)
 	} else {
-		podScraper = existing.(podinfo.PodInfo)
+		podInfo = existing.(podinfo.PodInfo)
 	}
 	// Update pod properties if anything changed.
-	podScraper.UpdatePod(pod)
+	podInfo.UpdatePod(pod)
 	return ok
 }

 func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
 	v, ok := ds.pods.LoadAndDelete(namespacedName)
 	if ok {
-		podScraper := v.(podinfo.PodInfo)
-		podScraper.Stop()
+		podInfo := v.(podinfo.PodInfo)
+		podInfo.Stop()
 	}
 }

@@ -313,10 +313,10 @@ func (ds *datastore) podResyncAll(ctx context.Context, ctrlClient client.Client)

 	// Remove pods that no longer belong to the pool or are no longer ready.
 	ds.pods.Range(func(k, v any) bool {
-		s := v.(podinfo.PodInfo)
-		if exist := activePods[s.GetPod().NamespacedName.Name]; !exist {
-			logger.V(logutil.VERBOSE).Info("Removing pod", "pod", s.GetPod())
-			ds.PodDelete(s.GetPod().NamespacedName)
+		podInfo := v.(podinfo.PodInfo)
+		if exist := activePods[podInfo.GetPod().NamespacedName.Name]; !exist {
+			logger.V(logutil.VERBOSE).Info("Removing pod", "pod", podInfo.GetPod())
+			ds.PodDelete(podInfo.GetPod().NamespacedName)
 		}
 		return true
 	})
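
One hunk above is more than a rename: in PodUpdateOrAddIfNotExist the old code declared the new entry with := inside the if block, shadowing the outer variable, so for every newly added pod the later UpdatePod call would have run on a nil interface. The fix switches to plain assignment. A self-contained illustration of the pitfall:

package main

import (
	"errors"
	"fmt"
)

func main() {
	var err error
	if err == nil {
		err := errors.New("assigned?") // := declares a NEW err that shadows the outer one
		_ = err
	}
	fmt.Println(err) // prints <nil>: the outer err was never assigned
}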

pkg/epp/datastore/datastore_test.go (+1 -1)

@@ -342,7 +342,7 @@ func TestMetrics(t *testing.T) {
 		WithScheme(scheme).
 		Build()
 	podInfoFactory := podinfo.NewPodInfoFactory(map[podinfo.Scraper]*podinfo.ScraperConfig{
-		&backendmetrics.FakeMetricsScraper{}: podinfo.NewScraperConfig(time.Second, 5*time.Second),
+		&backendmetrics.FakeMetricsScraper{}: podinfo.NewScraperConfig(time.Millisecond, 5*time.Second),
 	})
 	ds := NewDatastore(ctx, podInfoFactory)
 	_ = ds.PoolSet(ctx, fakeClient, inferencePool)

pkg/epp/scheduling/plugins/filter/filter.go (+40 -31)

@@ -21,9 +21,9 @@ import (
 	"math/rand"
 	"time"

-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
+	pluginutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/util"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )

@@ -139,19 +139,21 @@ func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []ty
 	filtered := []types.Pod{}

 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		if podMetrics.WaitingQueueSize <= min {
-			min = podMetrics.WaitingQueueSize
-		}
-		if podMetrics.WaitingQueueSize >= max {
-			max = podMetrics.WaitingQueueSize
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			if podMetrics.WaitingQueueSize <= min {
+				min = podMetrics.WaitingQueueSize
+			}
+			if podMetrics.WaitingQueueSize >= max {
+				max = podMetrics.WaitingQueueSize
+			}
 		}
 	}

 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		if podMetrics.WaitingQueueSize >= min && podMetrics.WaitingQueueSize <= min+(max-min)/len(pods) {
-			filtered = append(filtered, pod)
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			if podMetrics.WaitingQueueSize >= min && podMetrics.WaitingQueueSize <= min+(max-min)/len(pods) {
+				filtered = append(filtered, pod)
+			}
 		}
 	}
 	return filtered

@@ -179,19 +181,21 @@ func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []ty
 	filtered := []types.Pod{}

 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		if podMetrics.KVCacheUsagePercent <= min {
-			min = podMetrics.KVCacheUsagePercent
-		}
-		if podMetrics.KVCacheUsagePercent >= max {
-			max = podMetrics.KVCacheUsagePercent
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			if podMetrics.KVCacheUsagePercent <= min {
+				min = podMetrics.KVCacheUsagePercent
+			}
+			if podMetrics.KVCacheUsagePercent >= max {
+				max = podMetrics.KVCacheUsagePercent
+			}
 		}
 	}

 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		if podMetrics.KVCacheUsagePercent >= min && podMetrics.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
-			filtered = append(filtered, pod)
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			if podMetrics.KVCacheUsagePercent >= min && podMetrics.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
+				filtered = append(filtered, pod)
+			}
 		}
 	}
 	return filtered

@@ -226,14 +230,15 @@ func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod)

 	// Categorize pods based on affinity and availability
 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		_, active := podMetrics.ActiveModels[ctx.Req.ResolvedTargetModel]
-		_, waiting := podMetrics.WaitingModels[ctx.Req.ResolvedTargetModel]
-
-		if active || waiting {
-			filtered_affinity = append(filtered_affinity, pod)
-		} else if len(podMetrics.ActiveModels)+len(podMetrics.WaitingModels) < podMetrics.MaxActiveModels {
-			filtered_available = append(filtered_available, pod)
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			_, active := podMetrics.ActiveModels[ctx.Req.ResolvedTargetModel]
+			_, waiting := podMetrics.WaitingModels[ctx.Req.ResolvedTargetModel]
+
+			if active || waiting {
+				filtered_affinity = append(filtered_affinity, pod)
+			} else if len(podMetrics.ActiveModels)+len(podMetrics.WaitingModels) < podMetrics.MaxActiveModels {
+				filtered_available = append(filtered_available, pod)
+			}
 		}
 	}

@@ -267,15 +272,19 @@ type podPredicate func(req *types.LLMRequest, pod types.Pod) bool

 func queueThresholdPredicate(queueThreshold int) podPredicate {
 	return func(req *types.LLMRequest, pod types.Pod) bool {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		return podMetrics.WaitingQueueSize <= queueThreshold
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			return podMetrics.WaitingQueueSize <= queueThreshold
+		}
+		return false
 	}
 }

 func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate {
 	return func(req *types.LLMRequest, pod types.Pod) bool {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		return podMetrics.KVCacheUsagePercent <= kvCacheThreshold
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			return podMetrics.KVCacheUsagePercent <= kvCacheThreshold
+		}
+		return false
 	}
 }
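
Every hunk above leans on a new pluginutil.GetMetricsFromPodInfo helper defined in one of the 13 changed files not reproduced on this page. A plausible sketch, assuming it mirrors the logger's private helper and guards the type assertion so callers get nil instead of the old panicking cast:

package util

import (
	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// GetMetricsFromPodInfo extracts the scraped *Metrics for a pod, returning
// nil when the metrics entry is absent or holds an unexpected type.
func GetMetricsFromPodInfo(pod types.Pod) *backendmetrics.Metrics {
	data, ok := pod.GetData()[backendmetrics.MetricsDataKey]
	if !ok {
		return nil
	}
	podMetrics, ok := data.(*backendmetrics.Metrics)
	if !ok {
		return nil
	}
	return podMetrics
}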

pkg/epp/scheduling/plugins/scorer/kvcache.go (+7 -3)

@@ -17,7 +17,7 @@ limitations under the License.
 package scorer

 import (
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	pluginutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/util"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )

@@ -30,8 +30,12 @@ func (ss *KVCacheScorer) Name() string {
 func (ss *KVCacheScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
 	scores := make(map[types.Pod]float64, len(pods))
 	for _, pod := range pods {
-		podMetrics := pod.GetData()[metrics.MetricsDataKey].(*metrics.Metrics)
-		scores[pod] = 1 - podMetrics.KVCacheUsagePercent
+		if podMetrics := pluginutil.GetMetricsFromPodInfo(pod); podMetrics != nil {
+			scores[pod] = 1 - podMetrics.KVCacheUsagePercent
+		} else {
+			scores[pod] = 0
+		}
+
 	}
 	return scores
 }
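
The else branch is the behavioral change worth noting: a pod whose metrics have not been scraped yet now receives a score of 0 (the floor of the 1 - KVCacheUsagePercent scale, assuming usage stays within [0, 1]), so the scheduler simply deprioritizes it instead of panicking on the old unchecked type assertion.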
