diff --git a/.gitignore b/.gitignore index 4442b6516..599999112 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,9 @@ go.work.sum # generated docs site + +# tokenizer lib +lib + +# local configuration files +.envrc diff --git a/.golangci.yml b/.golangci.yml index 19139a67a..a42307fce 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,8 +1,6 @@ run: timeout: 5m allow-parallel-runners: true - skip-files: - - "pkg/epp/server/runserver_test.go" # Settings related to issues issues: @@ -21,7 +19,7 @@ linters: - fatcontext - ginkgolinter - gocritic - - govet + # - govet # do not enable - this causes some metalinter issue - loggercheck - misspell - perfsprint @@ -30,17 +28,13 @@ linters: - makezero - errcheck - goconst - - gofmt - - goimports - - gosimple - ineffassign - nakedret - prealloc - - typecheck - unparam - unused - + linters-settings: revive: rules: - - name: comment-spacings + - name: comment-spacings \ No newline at end of file diff --git a/.tekton/buildah-build.yaml b/.tekton/buildah-build.yaml index ad4ab4f40..d680a2333 100644 --- a/.tekton/buildah-build.yaml +++ b/.tekton/buildah-build.yaml @@ -44,6 +44,15 @@ spec: USERNAME=$(jq -r '.auths["quay.io"].username' /root/.docker/config.json) PASSWORD=$(jq -r '.auths["quay.io"].password' /root/.docker/config.json) + echo "🔐 Extracting Git credentials from workspace..." + GIT_USER=$(cat /workspace/git-auth/username) + GIT_TOKEN=$(cat /workspace/git-auth/token) + + if [ -z "$GIT_USER" ] || [ -z "$GIT_TOKEN" ]; then + echo "❌ Error: Missing git-auth credentials" + exit 1 + fi + if [ "$USERNAME" = "null" ] || [ "$PASSWORD" = "null" ]; then echo "❌ Error: Missing registry credentials" exit 1 @@ -56,8 +65,10 @@ spec: export DOCKER_CONFIG=/root/.docker export BUILDER=buildah export IMG=$(params.image_tag_base):$(params.dev-version) - + export GIT_NM_USER=$GIT_USER + export NM_TOKEN=$GIT_TOKEN + echo "🚀 Calling make buildah-build with IMG=$IMG..." 
- make buildah-build IMG=$IMG + make buildah-build IMG=$IMG echo "$IMG" > /tekton/results/image-url diff --git a/.tekton/go-build-task.yaml b/.tekton/go-build-task.yaml index eeb117976..579d20086 100644 --- a/.tekton/go-build-task.yaml +++ b/.tekton/go-build-task.yaml @@ -12,5 +12,24 @@ spec: script: | #!/bin/bash cd $(workspaces.source.path) + + echo "🔐 Extracting Git credentials from workspace..." + GIT_USER=$(cat /workspace/git-auth/username) + GIT_TOKEN=$(cat /workspace/git-auth/token) + + if [ -z "$GIT_USER" ] || [ -z "$GIT_TOKEN" ]; then + echo "❌ Error: Missing git-auth credentials" + exit 1 + fi + + echo "🔐 Configuring Git..." + git config --global user.email "ci-tag-bot@example.com" + git config --global user.name "ci-tag-bot" + git config --global url."https://${GIT_USER}:${GIT_TOKEN}@github.com".insteadOf "https://github.com" + git config --global --add safe.directory "$(pwd)" + + # required for go build with tokenizer lib linking + dnf install -y gcc-c++ libstdc++ libstdc++-devel && dnf clean all + go env -w GOFLAGS=-buildvcs=false make build diff --git a/.tekton/go-lint-task.yaml b/.tekton/go-lint-task.yaml index f42471a19..809a03223 100644 --- a/.tekton/go-lint-task.yaml +++ b/.tekton/go-lint-task.yaml @@ -11,6 +11,7 @@ spec: steps: - name: run-lint image: us.icr.io/ibm-hc4ai-operator/golangci-lint:v1.64.8 + # image: us.icr.io/ibm-hc4ai-operator/golangci-lint:v2.0.3 imagePullPolicy: IfNotPresent script: | #!/bin/bash diff --git a/.tekton/pipelinerun.yaml b/.tekton/pipelinerun.yaml index 29ef7b666..27cfe5c30 100644 --- a/.tekton/pipelinerun.yaml +++ b/.tekton/pipelinerun.yaml @@ -165,6 +165,9 @@ spec: workspaces: - name: source workspace: source + - name: git-auth + workspace: git-auth + - name: extract-version-and-registry params: @@ -328,6 +331,8 @@ spec: workspace: registry-secret - name: container-storage workspace: container-storage + - name: git-auth + workspace: git-auth - name: vulnerability-scan when: diff --git a/Makefile b/Makefile index 
bb4f078d9..b51bc16b0 100644 --- a/Makefile +++ b/Makefile @@ -439,11 +439,20 @@ lint: check-golangci-lint ## Run lint golangci-lint run ##@ Build +LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' +CGO_ENABLED=1 # Enable CGO + +.PHONY: download-tokenizer +download-tokenizer: ## Download the HuggingFace tokenizer bindings. + @echo "Downloading HuggingFace tokenizer bindings..." + mkdir -p lib + curl -L https://github.com/daulet/tokenizers/releases/download/v1.20.2/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + ranlib lib/*.a .PHONY: build -build: check-go ## +build: check-go download-tokenizer ## @printf "\033[33;1m==== Building ====\033[0m\n" - go build -o bin/epp cmd/epp/main.go cmd/epp/health.go + go build -ldflags="$(LDFLAGS)" -o bin/epp cmd/epp/main.go cmd/epp/health.go ##@ Container Build/Push @@ -456,7 +465,12 @@ buildah-build: check-builder load-version-json ## Build and push image (multi-ar for arch in amd64; do \ ARCH_TAG=$$FINAL_TAG-$$arch; \ echo "📦 Building for architecture: $$arch"; \ - buildah build --arch=$$arch --os=linux --layers -t $(IMG)-$$arch . || exit 1; \ + buildah build \ + --arch=$$arch \ + --build-arg GIT_NM_USER=$(GIT_NM_USER) \ + --build-arg NM_TOKEN=$(NM_TOKEN) \ + --os=linux \ + --layers -t $(IMG)-$$arch . || exit 1; \ echo "🚀 Pushing image: $(IMG)-$$arch"; \ buildah push $(IMG)-$$arch docker://$(IMG)-$$arch || exit 1; \ done; \ @@ -474,7 +488,11 @@ buildah-build: check-builder load-version-json ## Build and push image (multi-ar sed -e '1 s/\(^FROM\)/FROM --platform=$${BUILDPLATFORM}/' Dockerfile > Dockerfile.cross; \ - docker buildx create --use --name image-builder || true; \ docker buildx use image-builder; \ - docker buildx build --push --platform=$(PLATFORMS) --tag $(IMG) -f Dockerfile.cross . || exit 1; \ + docker buildx build --push \ + --platform=$(PLATFORMS) \ + --build-arg GIT_NM_USER=$(GIT_NM_USER)\ + --build-arg NM_TOKEN=$(NM_TOKEN) \ + --tag $(IMG) -f Dockerfile.cross . 
|| exit 1; \ docker buildx rm image-builder || true; \ rm Dockerfile.cross; \ elif [ "$(BUILDER)" = "podman" ]; then \ @@ -494,6 +512,7 @@ image-build: check-container-tool load-version-json ## Build container image usi --build-arg TARGETARCH=$(TARGETARCH) \ --build-arg GIT_NM_USER=$(GIT_NM_USER)\ --build-arg NM_TOKEN=$(NM_TOKEN) \ + --progress=plain \ -t $(IMG) . .PHONY: image-push diff --git a/README.md b/README.md index 12d4186ee..dd262dcfc 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,11 @@ To enable LoadAwareScorer, the following env vars must be configured: export ENABLE_LOAD_AWARE_SCORER=true export LOAD_AWARE_SCORER_WEIGHT=1.0 ``` + +To enable PDFilter, the following env var must be configured: +``` +export ENABLE_PD_FILTER=true +``` --- [Inference Gateways]:#concepts-and-definitions @@ -96,8 +101,8 @@ See our website at https://gateway-api-inference-extension.sigs.k8s.io/ for deta ## Roadmap As Inference Gateway builds towards a GA release. We will continue to expand our capabilities, namely: -1. Prefix-cache aware load balancing with interfaces for remote caches -1. Recommended LoRA adapter pipeline for automated rollout +1. Prefix-cache aware load balancing with interfaces for remote caches +1. Recommended LoRA adapter pipeline for automated rollout 1. Fairness and priority between workloads within the same criticality band 1. HPA support for autoscaling on aggregate metrics derived from the load balancer 1. Support for large multi-modal inputs and outputs @@ -121,4 +126,3 @@ Contributions are readily welcomed, follow the [dev guide](./docs/dev.md) to sta ### Code of conduct Participation in the Kubernetes community is governed by the [Kubernetes Code of Conduct](code-of-conduct.md). 
- diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index 901697cb4..5ebf8484e 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -32,7 +32,7 @@ import ( const ( fetchMetricsTimeout = 5 * time.Second - roleLabel = "llmd.org/role" + roleLabel = "llm-d.ai/role" rolePrefill = "prefill" roleDecode = "decode" roleBoth = "both" diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 4997a8b30..881c3653c 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -90,7 +90,7 @@ func (s *StreamingServer) HandleRequestBody( return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} } - res, err := s.scheduler.Schedule(ctx, llmReq) + res, err := s.scheduler.OnRequest(ctx, llmReq) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 6ea7d438c..712b37ad9 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -37,6 +37,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -65,7 +66,8 @@ type StreamingServer struct { } type Scheduler interface { - Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) + OnRequest(ctx context.Context, b *schedulingtypes.LLMRequest) (result 
*schedulingtypes.Result, err error) + OnResponse(ctx context.Context, req *types.LLMRequest, targetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. @@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) case *extProcPb.ProcessingRequest_RequestTrailers: // This is currently unused. case *extProcPb.ProcessingRequest_ResponseHeaders: + responseHeaders := make(map[string]string) for _, header := range v.ResponseHeaders.Headers.GetHeaders() { value := string(header.RawValue) @@ -199,27 +202,53 @@ reqCtx.modelServerStreaming = true loggerTrace.Info("model server is streaming response") } + responseHeaders[header.Key] = value } - reqCtx.RequestState = ResponseRecieved - reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, + llmReq := &schedulingtypes.LLMRequest{ + Model: reqCtx.Model, + Headers: responseHeaders, + ResolvedTargetModel: reqCtx.ResolvedTargetModel, + } + + var result *types.Result + result, err = s.scheduler.OnResponse(ctx, llmReq, reqCtx.TargetPod) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error handling response") + reqCtx.ResponseStatusCode = errutil.ModelServerError + } else { + headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + // This is for debugging purpose only.
+ Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + } + + // Add headers added by PostResponse + for key, value := range result.MutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } + + reqCtx.RequestState = ResponseRecieved + reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, }, }, }, }, - }, + } } case *extProcPb.ProcessingRequest_ResponseBody: diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 5c64228ca..3f064fe75 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -26,6 +26,7 @@ type SchedulerConfig struct { scorers map[plugins.Scorer]int // map from scorer to weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } var defPlugin = &defaultPlugin{} @@ -40,4 +41,5 @@ var defaultConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: defPlugin, postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } diff --git a/pkg/epp/scheduling/local_config.go b/pkg/epp/scheduling/local_config.go index 2e261a87a..d1df2459c 100644 --- a/pkg/epp/scheduling/local_config.go +++ b/pkg/epp/scheduling/local_config.go @@ -18,7 +18,9 @@ package scheduling import ( "context" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" @@ -28,16 +30,22 @@ import ( const ( 
kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER" loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER" + pdFilterEnablementEnvVar = "ENABLE_PD_FILTER" kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT" loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT" ) +func init() { + setDefaultConfig() +} + func setDefaultConfig() { // since the default config is a global variable, we add this function to minimize rebase conflicts. // this configuration is a temporary state, it should be better streamlined. setLoadAwareScorer() setKVCacheAwareScorer() + setPDFilter() defaultConfig.picker = picker.NewMaxScorePicker() } @@ -75,3 +83,16 @@ func setKVCacheAwareScorer() { defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight) } + +func setPDFilter() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" { + loggerDebug.Info("Skipping PDFilter creation as it is not enabled") + return + } + + defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter) + loggerDebug.Info("Initialized PDFilter") +} diff --git a/pkg/epp/scheduling/plugins/filter/pd_filter.go b/pkg/epp/scheduling/plugins/filter/pd_filter.go index 945c615d3..228d18143 100644 --- a/pkg/epp/scheduling/plugins/filter/pd_filter.go +++ b/pkg/epp/scheduling/plugins/filter/pd_filter.go @@ -19,8 +19,10 @@ import ( "fmt" "math/rand/v2" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) const ( @@ -42,6 +44,8 @@ var PDFilter = &baseFilter{ // Returns: // - Filtered slice of pod metrics, could contain one or zerro elements func 
prefillDecodeFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + loggerDebug := log.FromContext(ctx).WithName("pd_filter").V(logutil.DEBUG) + pPods := make([]types.Pod, 0) dPods := make([]types.Pod, 0) @@ -56,7 +60,10 @@ func prefillDecodeFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []t if len(pPods) > 0 { // select a random prefill pod randomIndex := rand.IntN(len(pPods)) - ctx.MutatedHeaders[prefillPodHeader] = fmt.Sprintf("http://%s:%d", pPods[randomIndex].GetPod().Address, ctx.TargetPort) + url := fmt.Sprintf("http://%s:%d", pPods[randomIndex].GetPod().Address, ctx.TargetPort) + loggerDebug.Info("Prefill pod selected", "url", url) + + ctx.MutatedHeaders[prefillPodHeader] = url } if len(dPods) > 1 { diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index f4e1714d4..5a39273f2 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -69,7 +69,6 @@ var ( ) func NewScheduler(datastore Datastore) *Scheduler { - setDefaultConfig() return NewSchedulerWithConfig(datastore, defaultConfig) } @@ -81,6 +80,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched scorers: config.scorers, picker: config.picker, postSchedulePlugins: config.postSchedulePlugins, + postResponsePlugins: config.postResponsePlugins, } } @@ -91,6 +91,7 @@ type Scheduler struct { scorers map[plugins.Scorer]int // map from scorer to its weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } type Datastore interface { @@ -98,8 +99,10 @@ type Datastore interface { PodGetAll() []backendmetrics.PodMetrics } -// Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { +// OnRequest finds the target pod based on metrics and the requested lora adapter. 
+// OnRequest is invoked during the processing of the request, before it is sent to +// appropriate pod for inference +func (s *Scheduler) OnRequest(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("request", req) loggerDebug := logger.V(logutil.DEBUG) @@ -211,6 +214,42 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty } } +// OnResponse is invoked during the processing of a response from an inference pod. It will invoke +// any defined plugins that process the response. +func (s *Scheduler) OnResponse(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + pool, err := s.datastore.PoolGet() + if err != nil { + return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + } + + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. 
+ pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll()) + var targetPod types.Pod + for _, pod := range pods { + if pod.GetPod().NamespacedName.String() == targetPodName { + targetPod = pod + break + } + } + + sCtx := types.NewSchedulingContext(ctx, req, pods, pool.Spec.TargetPortNumber) + + s.runPostResponsePlugins(sCtx, targetPod) + + return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil +} + +func (s *Scheduler) runPostResponsePlugins(ctx *types.SchedulingContext, targetPod types.Pod) { + for _, plugin := range s.postResponsePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostResponse(ctx, targetPod) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before)) + } +} + type defaultPlugin struct { picker.RandomPicker } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index e6d229aee..4c94fb8d1 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -232,7 +232,7 @@ func TestSchedule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { scheduler := NewScheduler(&fakeDataStore{pods: test.input}) - got, err := scheduler.Schedule(context.Background(), test.req) + got, err := scheduler.OnRequest(context.Background(), test.req) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } @@ -407,7 +407,7 @@ func TestSchedulePlugins(t *testing.T) { Model: "test-model", Headers: test.requestHeaders, } - got, err := scheduler.Schedule(context.Background(), req) + got, err := scheduler.OnRequest(context.Background(), req) // Validate error state if test.err != (err != nil) { @@ -483,6 +483,56 @@ func TestSchedulePlugins(t *testing.T) { } } +func TestPostResponse(t *testing.T) { + pr1 := &testPostResponse{ + NameRes: "pr1", + ExtraHeaders: 
map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + ReceivedResponseHeaders: make(map[string]string), + } + + tests := []struct { + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + responseHeaders map[string]string + wantMutatedHeaders map[string]string + }{ + { + name: "Simple postResponse test", + config: SchedulerConfig{ + postResponsePlugins: []plugins.PostResponse{pr1}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + }, + responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"}, + wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + }, + } + + for _, test := range tests { + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) + + req := &types.LLMRequest{ + Model: "test-model", + Headers: test.responseHeaders, + } + + result, err := scheduler.OnResponse(context.Background(), req, test.input[0].Pod.NamespacedName.String()) + if err != nil { + t.Errorf("Received an error. 
Error: %s", err) + } + + if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" { + t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff) + } + + if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" { + t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff) + } + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -571,6 +621,23 @@ func (tp *TestPlugin) reset() { tp.NumOfPickerCandidates = 0 } +type testPostResponse struct { + NameRes string + ReceivedResponseHeaders map[string]string + ExtraHeaders map[string]string +} + +func (pr *testPostResponse) Name() string { return pr.NameRes } + +func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { + for key, value := range ctx.Req.Headers { + pr.ReceivedResponseHeaders[key] = value + } + for key, value := range pr.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } +} + func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { res := []types.Pod{} for _, pod := range ctx.PodsSnapshot { diff --git a/pkg/epp/scheduling/scorers_test.go b/pkg/epp/scheduling/scorers_test.go index a98a838b1..d4c3d1e13 100644 --- a/pkg/epp/scheduling/scorers_test.go +++ b/pkg/epp/scheduling/scorers_test.go @@ -86,19 +86,23 @@ func TestScorers(t *testing.T) { }, }, wantRes: &types.Result{ - TargetPod: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + TargetPod: &types.ScoredPod{ + Pod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + 
"foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, + Score: 0.5, }, + MutatedHeaders: map[string]string{}, }, }, } @@ -108,7 +112,7 @@ func TestScorers(t *testing.T) { scheduler := NewScheduler(&fakeDataStore{pods: test.input}) scheduler.scorers = map[plugins.Scorer]int{test.scorer: 1} scheduler.picker = &picker.MaxScorePicker{} - got, err := scheduler.Schedule(context.Background(), test.req) + got, err := scheduler.OnRequest(context.Background(), test.req) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } diff --git a/pkg/epp/server/runserver_test.go b/pkg/epp/server/runserver_test.go index b02688c58..0cb52d6d2 100644 --- a/pkg/epp/server/runserver_test.go +++ b/pkg/epp/server/runserver_test.go @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - package server_test import ( @@ -25,6 +24,9 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +// Define a variable with the manager package type to explicitly show usage to linter +var _ manager.LeaderElectionRunnable = nil + func TestRunnable(t *testing.T) { // Make sure AsRunnable() does not use leader election. runner := server.NewDefaultExtProcServerRunner().AsRunnable(logutil.NewTestLogger())