Skip to content

[UPSTREAM-SYNC] Added support for mutating headers #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
reqCtx.TargetPod = targetPod.NamespacedName.String()
reqCtx.TargetEndpoint = endpoint

s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders)

reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
Expand Down Expand Up @@ -148,7 +148,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
return err
}
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil)
return nil
}

Expand Down
11 changes: 10 additions & 1 deletion pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return nil
}

func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) {
headers := []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Expand All @@ -394,6 +394,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext,
},
})
}
// Append the header mutations requested by filters/scorers (map iteration order is unspecified, so their relative order may vary).
for key, value := range mutatedHeaders {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
RawValue: []byte(value),
},
})
}

targetEndpointValue := &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down
1 change: 1 addition & 0 deletions pkg/epp/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types

s.runPostSchedulePlugins(sCtx, result)

result.MutatedHeaders = sCtx.MutatedHeaders
return result, nil
}

Expand Down
121 changes: 87 additions & 34 deletions pkg/epp/scheduling/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func TestSchedule(t *testing.T) {
},
},
},
MutatedHeaders: make(map[string]string),
},
},
{
Expand Down Expand Up @@ -172,6 +173,7 @@ func TestSchedule(t *testing.T) {
},
},
},
MutatedHeaders: make(map[string]string),
},
},
{
Expand Down Expand Up @@ -242,30 +244,41 @@ func TestSchedule(t *testing.T) {

func TestSchedulePlugins(t *testing.T) {
tp1 := &TestPlugin{
NameRes: "test1",
ScoreRes: 0.3,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
NameRes: "test1",
ScoreRes: 0.3,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
ReceivedRequestHeaders: make(map[string]string),
}
tp2 := &TestPlugin{
NameRes: "test2",
ScoreRes: 0.8,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
NameRes: "test2",
ScoreRes: 0.8,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
ReceivedRequestHeaders: make(map[string]string),
}
tp_filterAll := &TestPlugin{
NameRes: "filter all",
FilterRes: []k8stypes.NamespacedName{},
NameRes: "filter all",
FilterRes: []k8stypes.NamespacedName{},
ReceivedRequestHeaders: make(map[string]string),
}
tp_headers := &TestPlugin{
NameRes: "headers",
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
ReceivedRequestHeaders: make(map[string]string),
}
pickerPlugin := &TestPlugin{
NameRes: "picker",
PickRes: k8stypes.NamespacedName{Name: "pod1"},
}

tests := []struct {
name string
config SchedulerConfig
input []*backendmetrics.FakePodMetrics
wantTargetPod k8stypes.NamespacedName
targetPodScore float64
name string
config SchedulerConfig
input []*backendmetrics.FakePodMetrics
requestHeaders map[string]string
wantTargetPod k8stypes.NamespacedName
wantMutatedHeaders map[string]string
targetPodScore float64
// Number of expected pods to score (after filter)
numPodsToScore int
err bool
Expand All @@ -287,10 +300,11 @@ func TestSchedulePlugins(t *testing.T) {
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
targetPodScore: 1.1,
numPodsToScore: 2,
err: false,
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: make(map[string]string),
targetPodScore: 1.1,
numPodsToScore: 2,
err: false,
},
{
name: "all plugins executed successfully, different scorers weights",
Expand All @@ -309,10 +323,11 @@ func TestSchedulePlugins(t *testing.T) {
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
targetPodScore: 50,
numPodsToScore: 2,
err: false,
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: make(map[string]string),
targetPodScore: 50,
numPodsToScore: 2,
err: false,
},
{
name: "filter all",
Expand All @@ -334,6 +349,33 @@ func TestSchedulePlugins(t *testing.T) {
numPodsToScore: 0,
err: true, // no pods are left to serve the request once the "filter all" plugin has run
},
{
name: "Mutate a header",
config: SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{tp1, tp2},
filters: []plugins.Filter{tp_headers},
scorers: map[plugins.Scorer]int{
tp1: 1,
tp2: 1,
},
picker: pickerPlugin,
postSchedulePlugins: []plugins.PostSchedule{tp1, tp2},
},
input: []*backendmetrics.FakePodMetrics{
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
requestHeaders: map[string]string{
"Content-type": "application/json",
"x-session-id": "qazw-edcr-tgby-nhyu",
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
targetPodScore: 1.1,
numPodsToScore: 2,
err: false, // scheduling succeeds; the "headers" filter keeps pod1 and pod2 and only adds a mutated header
},
}

for _, test := range tests {
Expand Down Expand Up @@ -372,7 +414,10 @@ func TestSchedulePlugins(t *testing.T) {
wantPod := &types.PodMetrics{
Pod: &backend.Pod{NamespacedName: test.wantTargetPod},
}
wantRes := &types.Result{TargetPod: wantPod}
wantRes := &types.Result{
TargetPod: wantPod,
MutatedHeaders: test.wantMutatedHeaders,
}
if diff := cmp.Diff(wantRes, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
Expand Down Expand Up @@ -437,18 +482,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {

// TestPlugin is an implementation useful in unit tests.
// Its fields hold both the canned results a test configures up front
// (NameRes, ScoreRes, FilterRes, PickRes, ExtraHeaders) and bookkeeping that
// tests can inspect after scheduling (e.g. FilterCallCount,
// ReceivedRequestHeaders).
type TestPlugin struct {
NameRes string
ScoreCallCount int
NumOfScoredPods int
ScoreRes float64
FilterCallCount int
FilterRes []k8stypes.NamespacedName
PreScheduleCallCount int
PostScheduleCallCount int
PickCallCount int
NumOfPickerCandidates int
PickRes k8stypes.NamespacedName
WinnderPodScore float64
NameRes string
ScoreCallCount int
NumOfScoredPods int
ScoreRes float64
FilterCallCount int
FilterRes []k8stypes.NamespacedName
PreScheduleCallCount int
PostScheduleCallCount int
PickCallCount int
NumOfPickerCandidates int
PickRes k8stypes.NamespacedName
WinnderPodScore float64
// ExtraHeaders are copied into the scheduling context's MutatedHeaders
// each time Filter is called.
ExtraHeaders map[string]string
// ReceivedRequestHeaders accumulates the request headers observed by
// Filter, keyed by header name.
ReceivedRequestHeaders map[string]string
}

func (tp *TestPlugin) Name() string { return tp.NameRes }
Expand All @@ -459,6 +506,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) {

// Filter is the test double for the Filter extension point.
//
// On every call it:
//   - increments FilterCallCount,
//   - copies ExtraHeaders into ctx.MutatedHeaders,
//   - copies the request headers from ctx.Req.Headers into
//     ReceivedRequestHeaders, and
//   - returns whatever findPods resolves for the configured FilterRes.
//
// The pods argument is not consulted; the outcome is driven entirely by
// FilterRes. Destination maps are created on demand so that a TestPlugin or
// SchedulingContext built without them does not panic on a nil-map write.
func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
	tp.FilterCallCount++

	if len(tp.ExtraHeaders) > 0 {
		// Writing to a nil map panics, so allocate lazily.
		if ctx.MutatedHeaders == nil {
			ctx.MutatedHeaders = make(map[string]string, len(tp.ExtraHeaders))
		}
		for key, value := range tp.ExtraHeaders {
			ctx.MutatedHeaders[key] = value
		}
	}

	if len(ctx.Req.Headers) > 0 {
		// Same nil-map protection for plugins declared without the field.
		if tp.ReceivedRequestHeaders == nil {
			tp.ReceivedRequestHeaders = make(map[string]string, len(ctx.Req.Headers))
		}
		for key, value := range ctx.Req.Headers {
			tp.ReceivedRequestHeaders[key] = value
		}
	}

	return findPods(ctx, tp.FilterRes...)
}
Expand Down
19 changes: 11 additions & 8 deletions pkg/epp/scheduling/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ type ScoredPod struct {
// SchedulingContext holds contextual information during a scheduling operation.
// It embeds the caller's context.Context and carries the request being
// scheduled, a logger, the pods snapshot, and the headers plugins want mutated.
type SchedulingContext struct {
context.Context
Logger logr.Logger
Req *LLMRequest
PodsSnapshot []Pod
Logger logr.Logger
Req *LLMRequest
PodsSnapshot []Pod
// MutatedHeaders collects header key/value pairs written by plugins while
// scheduling; the scheduler hands this map back to the caller through
// Result.MutatedHeaders.
MutatedHeaders map[string]string
}

func (pm *PodMetrics) String() string {
Expand All @@ -87,10 +88,11 @@ type PodMetrics struct {
// NewSchedulingContext builds the SchedulingContext for scheduling req over
// the given pods. The logger is taken from ctx and tagged with the request,
// and MutatedHeaders starts out as an empty, non-nil map so plugins can write
// to it without allocating it themselves.
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
logger := log.FromContext(ctx).WithValues("request", req)
return &SchedulingContext{
Context: ctx,
Logger: logger,
Req: req,
PodsSnapshot: pods,
Context: ctx,
Logger: logger,
Req: req,
PodsSnapshot: pods,
MutatedHeaders: make(map[string]string),
}
}

Expand All @@ -104,5 +106,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {

// Result captures the scheduler result: the target pod, plus the header
// mutations gathered in the scheduling context during the run.
type Result struct {
TargetPod Pod
TargetPod Pod
// MutatedHeaders is set from SchedulingContext.MutatedHeaders at the end
// of Schedule.
MutatedHeaders map[string]string
}