From 70f65783c145f4ebd17de07f9eb0e52298a86fcd Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Tue, 6 May 2025 19:04:21 +0300 Subject: [PATCH 1/2] Added support for plugins to add/modify headers Signed-off-by: Shmuel Kallner --- pkg/epp/handlers/request.go | 4 ++-- pkg/epp/handlers/server.go | 11 ++++++++++- pkg/epp/scheduling/scheduler.go | 1 + pkg/epp/scheduling/types/types.go | 19 +++++++++++-------- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 5495c1da..a1151d7d 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -106,7 +106,7 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request reqCtx.TargetPod = targetPod.NamespacedName.String() reqCtx.TargetEndpoint = endpoint - s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes)) + s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders) reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header @@ -148,7 +148,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ return err } endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - s.populateRequestHeaderResponse(reqCtx, endpoint, 0) + s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil) return nil } diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 711d8cee..37d84027 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -375,7 +375,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces return nil } -func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) { +func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) { headers := []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ @@ -394,6 +394,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, }, }) } + // Add headers added by filters/scorers + for key, value := range mutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } targetEndpointValue := &structpb.Struct{ Fields: map[string]*structpb.Value{ diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 9215489f..78a4f93d 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -124,6 +124,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types s.runPostSchedulePlugins(sCtx, result) + result.MutatedHeaders = sCtx.MutatedHeaders return result, nil } diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 795ef65d..aaefcf5e 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -59,9 +59,10 @@ type ScoredPod struct { // SchedulingContext holds contextual information during a scheduling operation. type SchedulingContext struct { context.Context - Logger logr.Logger - Req *LLMRequest - PodsSnapshot []Pod + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod + MutatedHeaders map[string]string } func (pm *PodMetrics) String() string { @@ -87,10 +88,11 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + MutatedHeaders: make(map[string]string), } } @@ -104,5 +106,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { // Result captures the scheduler result. type Result struct { - TargetPod Pod + TargetPod Pod + MutatedHeaders map[string]string } From c0f590518bbd708f04aa24db05d208091f11d737 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Tue, 6 May 2025 19:04:55 +0300 Subject: [PATCH 2/2] Added mutating header tests Signed-off-by: Shmuel Kallner --- pkg/epp/scheduling/scheduler_test.go | 121 +++++++++++++++++++-------- 1 file changed, 87 insertions(+), 34 deletions(-) diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index b44c7ac2..da2874c0 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -109,6 +109,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -172,6 +173,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -242,18 +244,27 @@ func TestSchedule(t *testing.T) { func TestSchedulePlugins(t *testing.T) { tp1 := &TestPlugin{ - NameRes: "test1", - ScoreRes: 0.3, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + ReceivedRequestHeaders: make(map[string]string), } tp2 := &TestPlugin{ - NameRes: "test2", - ScoreRes: 0.8, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ReceivedRequestHeaders: make(map[string]string), } tp_filterAll := &TestPlugin{ - NameRes: "filter all", - FilterRes: []k8stypes.NamespacedName{}, + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, + ReceivedRequestHeaders: make(map[string]string), + } + tp_headers := &TestPlugin{ + NameRes: "headers", + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + ReceivedRequestHeaders: make(map[string]string), } pickerPlugin := &TestPlugin{ NameRes: "picker", @@ -261,11 +272,13 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - config SchedulerConfig - input []*backendmetrics.FakePodMetrics - wantTargetPod k8stypes.NamespacedName - targetPodScore float64 + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + requestHeaders map[string]string + wantTargetPod k8stypes.NamespacedName + wantMutatedHeaders map[string]string + targetPodScore float64 // Number of expected pods to score (after filter) numPodsToScore int err bool @@ -287,10 +300,11 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, }, { name: "all plugins executed successfully, different scorers weights", @@ -309,10 +323,11 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 50, - numPodsToScore: 2, - err: false, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 50, + numPodsToScore: 2, + err: false, }, { name: "filter all", @@ -334,6 +349,33 @@ func TestSchedulePlugins(t *testing.T) { numPodsToScore: 0, err: true, // no available pods to server after filter all }, + { + name: "Mutate a header", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp_headers}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, + picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + requestHeaders: map[string]string{ + "Content-type": "application/json", + "x-session-id": "qazw-edcr-tgby-nhyu", + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, // no available pods to server after filter all + }, } for _, test := range tests { @@ -372,7 +414,10 @@ func TestSchedulePlugins(t *testing.T) { wantPod := &types.PodMetrics{ Pod: &backend.Pod{NamespacedName: test.wantTargetPod}, } - wantRes := &types.Result{TargetPod: wantPod} + wantRes := &types.Result{ + TargetPod: wantPod, + MutatedHeaders: test.wantMutatedHeaders, + } if diff := cmp.Diff(wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -437,18 +482,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { // TestPlugin is an implementation useful in unit tests. type TestPlugin struct { - NameRes string - ScoreCallCount int - NumOfScoredPods int - ScoreRes float64 - FilterCallCount int - FilterRes []k8stypes.NamespacedName - PreScheduleCallCount int - PostScheduleCallCount int - PickCallCount int - NumOfPickerCandidates int - PickRes k8stypes.NamespacedName - WinnderPodScore float64 + NameRes string + ScoreCallCount int + NumOfScoredPods int + ScoreRes float64 + FilterCallCount int + FilterRes []k8stypes.NamespacedName + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + NumOfPickerCandidates int + PickRes k8stypes.NamespacedName + WinnderPodScore float64 + ExtraHeaders map[string]string + ReceivedRequestHeaders map[string]string } func (tp *TestPlugin) Name() string { return tp.NameRes } @@ -459,6 +506,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ + for key, value := range tp.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } + for key, value := range ctx.Req.Headers { + tp.ReceivedRequestHeaders[key] = value + } return findPods(ctx, tp.FilterRes...) }