Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c5bf036

Browse files
authored
[UPSTREAM-SYNC] Added support for mutating headers (#129)
* Added support for plugins to add/modify headers
Signed-off-by: Shmuel Kallner <[email protected]>
* Added mutating header tests
Signed-off-by: Shmuel Kallner <[email protected]>
---------
Signed-off-by: Shmuel Kallner <[email protected]>
1 parent 4b42c38 commit c5bf036

File tree

5 files changed

+111
-45
lines changed

5 files changed

+111
-45
lines changed

pkg/epp/handlers/request.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
106106
reqCtx.TargetPod = targetPod.NamespacedName.String()
107107
reqCtx.TargetEndpoint = endpoint
108108

109-
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))
109+
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders)
110110

111111
reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
112112
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
@@ -148,7 +148,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
148148
return err
149149
}
150150
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
151-
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
151+
s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil)
152152
return nil
153153
}
154154

pkg/epp/handlers/server.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
375375
return nil
376376
}
377377

378-
func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
378+
func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) {
379379
headers := []*configPb.HeaderValueOption{
380380
{
381381
Header: &configPb.HeaderValue{
@@ -394,6 +394,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext,
394394
},
395395
})
396396
}
397+
// Append the headers that filters/scorers asked to add or modify
398+
for key, value := range mutatedHeaders {
399+
headers = append(headers, &configPb.HeaderValueOption{
400+
Header: &configPb.HeaderValue{
401+
Key: key,
402+
RawValue: []byte(value),
403+
},
404+
})
405+
}
397406

398407
targetEndpointValue := &structpb.Struct{
399408
Fields: map[string]*structpb.Value{

pkg/epp/scheduling/scheduler.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
124124

125125
s.runPostSchedulePlugins(sCtx, result)
126126

127+
result.MutatedHeaders = sCtx.MutatedHeaders
127128
return result, nil
128129
}
129130

pkg/epp/scheduling/scheduler_test.go

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ func TestSchedule(t *testing.T) {
109109
},
110110
},
111111
},
112+
MutatedHeaders: make(map[string]string),
112113
},
113114
},
114115
{
@@ -172,6 +173,7 @@ func TestSchedule(t *testing.T) {
172173
},
173174
},
174175
},
176+
MutatedHeaders: make(map[string]string),
175177
},
176178
},
177179
{
@@ -242,30 +244,41 @@ func TestSchedule(t *testing.T) {
242244

243245
func TestSchedulePlugins(t *testing.T) {
244246
tp1 := &TestPlugin{
245-
NameRes: "test1",
246-
ScoreRes: 0.3,
247-
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
247+
NameRes: "test1",
248+
ScoreRes: 0.3,
249+
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
250+
ReceivedRequestHeaders: make(map[string]string),
248251
}
249252
tp2 := &TestPlugin{
250-
NameRes: "test2",
251-
ScoreRes: 0.8,
252-
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
253+
NameRes: "test2",
254+
ScoreRes: 0.8,
255+
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
256+
ReceivedRequestHeaders: make(map[string]string),
253257
}
254258
tp_filterAll := &TestPlugin{
255-
NameRes: "filter all",
256-
FilterRes: []k8stypes.NamespacedName{},
259+
NameRes: "filter all",
260+
FilterRes: []k8stypes.NamespacedName{},
261+
ReceivedRequestHeaders: make(map[string]string),
262+
}
263+
tp_headers := &TestPlugin{
264+
NameRes: "headers",
265+
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
266+
ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
267+
ReceivedRequestHeaders: make(map[string]string),
257268
}
258269
pickerPlugin := &TestPlugin{
259270
NameRes: "picker",
260271
PickRes: k8stypes.NamespacedName{Name: "pod1"},
261272
}
262273

263274
tests := []struct {
264-
name string
265-
config SchedulerConfig
266-
input []*backendmetrics.FakePodMetrics
267-
wantTargetPod k8stypes.NamespacedName
268-
targetPodScore float64
275+
name string
276+
config SchedulerConfig
277+
input []*backendmetrics.FakePodMetrics
278+
requestHeaders map[string]string
279+
wantTargetPod k8stypes.NamespacedName
280+
wantMutatedHeaders map[string]string
281+
targetPodScore float64
269282
// Number of expected pods to score (after filter)
270283
numPodsToScore int
271284
err bool
@@ -287,10 +300,11 @@ func TestSchedulePlugins(t *testing.T) {
287300
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
288301
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
289302
},
290-
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
291-
targetPodScore: 1.1,
292-
numPodsToScore: 2,
293-
err: false,
303+
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
304+
wantMutatedHeaders: make(map[string]string),
305+
targetPodScore: 1.1,
306+
numPodsToScore: 2,
307+
err: false,
294308
},
295309
{
296310
name: "all plugins executed successfully, different scorers weights",
@@ -309,10 +323,11 @@ func TestSchedulePlugins(t *testing.T) {
309323
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
310324
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
311325
},
312-
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
313-
targetPodScore: 50,
314-
numPodsToScore: 2,
315-
err: false,
326+
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
327+
wantMutatedHeaders: make(map[string]string),
328+
targetPodScore: 50,
329+
numPodsToScore: 2,
330+
err: false,
316331
},
317332
{
318333
name: "filter all",
@@ -334,6 +349,33 @@ func TestSchedulePlugins(t *testing.T) {
334349
numPodsToScore: 0,
335350
err: true, // no available pods to serve after filter all
336351
},
352+
{
353+
name: "Mutate a header",
354+
config: SchedulerConfig{
355+
preSchedulePlugins: []plugins.PreSchedule{tp1, tp2},
356+
filters: []plugins.Filter{tp_headers},
357+
scorers: map[plugins.Scorer]int{
358+
tp1: 1,
359+
tp2: 1,
360+
},
361+
picker: pickerPlugin,
362+
postSchedulePlugins: []plugins.PostSchedule{tp1, tp2},
363+
},
364+
input: []*backendmetrics.FakePodMetrics{
365+
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
366+
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
367+
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
368+
},
369+
requestHeaders: map[string]string{
370+
"Content-type": "application/json",
371+
"x-session-id": "qazw-edcr-tgby-nhyu",
372+
},
373+
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
374+
wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
375+
targetPodScore: 1.1,
376+
numPodsToScore: 2,
377+
err: false, // scheduling succeeds: the headers filter keeps pod1 and pod2
378+
},
337379
}
338380

339381
for _, test := range tests {
@@ -372,7 +414,10 @@ func TestSchedulePlugins(t *testing.T) {
372414
wantPod := &types.PodMetrics{
373415
Pod: &backend.Pod{NamespacedName: test.wantTargetPod},
374416
}
375-
wantRes := &types.Result{TargetPod: wantPod}
417+
wantRes := &types.Result{
418+
TargetPod: wantPod,
419+
MutatedHeaders: test.wantMutatedHeaders,
420+
}
376421
if diff := cmp.Diff(wantRes, got); diff != "" {
377422
t.Errorf("Unexpected output (-want +got): %v", diff)
378423
}
@@ -437,18 +482,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {
437482

438483
// TestPlugin is an implementation useful in unit tests.
439484
type TestPlugin struct {
440-
NameRes string
441-
ScoreCallCount int
442-
NumOfScoredPods int
443-
ScoreRes float64
444-
FilterCallCount int
445-
FilterRes []k8stypes.NamespacedName
446-
PreScheduleCallCount int
447-
PostScheduleCallCount int
448-
PickCallCount int
449-
NumOfPickerCandidates int
450-
PickRes k8stypes.NamespacedName
451-
WinnderPodScore float64
485+
NameRes string
486+
ScoreCallCount int
487+
NumOfScoredPods int
488+
ScoreRes float64
489+
FilterCallCount int
490+
FilterRes []k8stypes.NamespacedName
491+
PreScheduleCallCount int
492+
PostScheduleCallCount int
493+
PickCallCount int
494+
NumOfPickerCandidates int
495+
PickRes k8stypes.NamespacedName
496+
WinnderPodScore float64
497+
ExtraHeaders map[string]string
498+
ReceivedRequestHeaders map[string]string
452499
}
453500

454501
func (tp *TestPlugin) Name() string { return tp.NameRes }
@@ -459,6 +506,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) {
459506

460507
func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
461508
tp.FilterCallCount++
509+
for key, value := range tp.ExtraHeaders {
510+
ctx.MutatedHeaders[key] = value
511+
}
512+
for key, value := range ctx.Req.Headers {
513+
tp.ReceivedRequestHeaders[key] = value
514+
}
462515
return findPods(ctx, tp.FilterRes...)
463516

464517
}

pkg/epp/scheduling/types/types.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ type ScoredPod struct {
5959
// SchedulingContext holds contextual information during a scheduling operation.
6060
type SchedulingContext struct {
6161
context.Context
62-
Logger logr.Logger
63-
Req *LLMRequest
64-
PodsSnapshot []Pod
62+
Logger logr.Logger
63+
Req *LLMRequest
64+
PodsSnapshot []Pod
65+
MutatedHeaders map[string]string
6566
}
6667

6768
func (pm *PodMetrics) String() string {
@@ -87,10 +88,11 @@ type PodMetrics struct {
8788
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
8889
logger := log.FromContext(ctx).WithValues("request", req)
8990
return &SchedulingContext{
90-
Context: ctx,
91-
Logger: logger,
92-
Req: req,
93-
PodsSnapshot: pods,
91+
Context: ctx,
92+
Logger: logger,
93+
Req: req,
94+
PodsSnapshot: pods,
95+
MutatedHeaders: make(map[string]string),
9496
}
9597
}
9698

@@ -104,5 +106,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {
104106

105107
// Result captures the scheduler result.
106108
type Result struct {
107-
TargetPod Pod
109+
TargetPod Pod
110+
MutatedHeaders map[string]string
108111
}

0 commit comments

Comments
 (0)