diff --git a/.gitignore b/.gitignore index bd36bfbb3..4442b6516 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,6 @@ bin/* Dockerfile.cross artifacts -main # Test binary, built with `go test -c` *.test diff --git a/cmd/bbr/health.go b/cmd/bbr/health.go index 4ef8a8b17..7d1b5fd53 100644 --- a/cmd/bbr/health.go +++ b/cmd/bbr/health.go @@ -19,7 +19,6 @@ package main import ( "context" - extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/go-logr/logr" "google.golang.org/grpc/codes" healthPb "google.golang.org/grpc/health/grpc_health_v1" @@ -32,29 +31,10 @@ type healthServer struct { } func (s *healthServer) Check(ctx context.Context, in *healthPb.HealthCheckRequest) (*healthPb.HealthCheckResponse, error) { - if in.Service != extProcPb.ExternalProcessor_ServiceDesc.ServiceName { - s.logger.V(logutil.DEFAULT).Info("gRPC health check requested unknown service", "available-services", []string{extProcPb.ExternalProcessor_ServiceDesc.ServiceName}, "requested-service", in.Service) - return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVICE_UNKNOWN}, nil - } - - s.logger.V(logutil.VERBOSE).Info("gRPC health check serving", "service", extProcPb.ExternalProcessor_ServiceDesc.ServiceName) + s.logger.V(logutil.VERBOSE).Info("gRPC health check serving", "service", in.Service) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil } -func (s *healthServer) List(ctx context.Context, _ *healthPb.HealthListRequest) (*healthPb.HealthListResponse, error) { - // currently only the ext_proc service is provided - serviceHealthResponse, err := s.Check(ctx, &healthPb.HealthCheckRequest{Service: extProcPb.ExternalProcessor_ServiceDesc.ServiceName}) - if err != nil { - return nil, err - } - - return &healthPb.HealthListResponse{ - Statuses: map[string]*healthPb.HealthCheckResponse{ - extProcPb.ExternalProcessor_ServiceDesc.ServiceName: serviceHealthResponse, - }, - }, nil -} - func (s *healthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Health_WatchServer) error { return status.Error(codes.Unimplemented, "Watch is not implemented") } diff --git a/cmd/epp/health.go b/cmd/epp/health.go index d5727d654..936970021 100644 --- a/cmd/epp/health.go +++ b/cmd/epp/health.go @@ -19,7 +19,6 @@ package main import ( "context" - extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/go-logr/logr" "google.golang.org/grpc/codes" healthPb "google.golang.org/grpc/health/grpc_health_v1" @@ -34,34 +33,14 @@ type healthServer struct { } func (s *healthServer) Check(ctx context.Context, in *healthPb.HealthCheckRequest) (*healthPb.HealthCheckResponse, error) { - if in.Service != extProcPb.ExternalProcessor_ServiceDesc.ServiceName { - s.logger.V(logutil.DEFAULT).Info("gRPC health check requested unknown service", "available-services", []string{extProcPb.ExternalProcessor_ServiceDesc.ServiceName}, "requested-service", in.Service) - return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVICE_UNKNOWN}, nil - } - if !s.datastore.PoolHasSynced() { - s.logger.V(logutil.DEFAULT).Info("gRPC health check not serving", "service", extProcPb.ExternalProcessor_ServiceDesc.ServiceName) + s.logger.V(logutil.DEFAULT).Info("gRPC health check not serving", "service", in.Service) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_NOT_SERVING}, nil } - - s.logger.V(logutil.TRACE).Info("gRPC health check serving", "service", 
extProcPb.ExternalProcessor_ServiceDesc.ServiceName) + s.logger.V(logutil.TRACE).Info("gRPC health check serving", "service", in.Service) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil } -func (s *healthServer) List(ctx context.Context, _ *healthPb.HealthListRequest) (*healthPb.HealthListResponse, error) { - // currently only the ext_proc service is provided - serviceHealthResponse, err := s.Check(ctx, &healthPb.HealthCheckRequest{Service: extProcPb.ExternalProcessor_ServiceDesc.ServiceName}) - if err != nil { - return nil, err - } - - return &healthPb.HealthListResponse{ - Statuses: map[string]*healthPb.HealthCheckResponse{ - extProcPb.ExternalProcessor_ServiceDesc.ServiceName: serviceHealthResponse, - }, - }, nil -} - func (s *healthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Health_WatchServer) error { return status.Error(codes.Unimplemented, "Watch is not implemented") } diff --git a/config/charts/inferencepool/templates/epp-deployment.yaml b/config/charts/inferencepool/templates/epp-deployment.yaml index c391d06ac..fc490210a 100644 --- a/config/charts/inferencepool/templates/epp-deployment.yaml +++ b/config/charts/inferencepool/templates/epp-deployment.yaml @@ -53,12 +53,12 @@ spec: livenessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 readinessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 diff --git a/config/manifests/benchmark/benchmark.yaml b/config/manifests/benchmark/benchmark.yaml index c784730e8..abf9ae5f6 100644 --- a/config/manifests/benchmark/benchmark.yaml +++ b/config/manifests/benchmark/benchmark.yaml @@ -37,7 +37,7 @@ spec: - name: BACKEND value: vllm - name: PORT - value: "8081" + value: "80" - name: INPUT_LENGTH value: "1024" - name: OUTPUT_LENGTH diff --git a/config/manifests/inferencepool-resources.yaml b/config/manifests/inferencepool-resources.yaml index e1b359e6d..37099c43a 100644 --- a/config/manifests/inferencepool-resources.yaml +++ b/config/manifests/inferencepool-resources.yaml @@ -71,13 +71,13 @@ spec: livenessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 readinessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 --- diff --git a/go.mod b/go.mod index 9469d0d38..9888e7be9 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/stretchr/testify v1.10.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 - google.golang.org/grpc v1.72.0 + google.golang.org/grpc v1.71.1 google.golang.org/protobuf v1.36.6 k8s.io/api v0.32.4 k8s.io/apiextensions-apiserver v0.32.4 @@ -25,13 +25,13 @@ require ( k8s.io/component-base v0.32.4 k8s.io/utils v0.0.0-20241210054802-24370beab758 sigs.k8s.io/controller-runtime v0.20.4 - sigs.k8s.io/gateway-api v1.2.1 - sigs.k8s.io/structured-merge-diff/v4 v4.6.0 + sigs.k8s.io/gateway-api v1.3.0 + sigs.k8s.io/structured-merge-diff/v4 v4.7.0 sigs.k8s.io/yaml v1.4.0 ) require ( - cel.dev/expr v0.20.0 // indirect + cel.dev/expr v0.19.1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect @@ -41,12 +41,12 @@ require ( github.com/blang/semver/v4 
v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 // indirect + github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.12.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect - github.com/fatih/color v1.17.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -58,7 +58,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/gobuffalo/flect v1.0.2 // indirect + github.com/gobuffalo/flect v1.0.3 // indirect github.com/goccy/go-yaml v1.11.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -91,8 +91,8 @@ require ( github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/spf13/cobra v1.8.1 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/cobra v1.9.1 // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect github.com/x448/float16 v0.8.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect @@ -109,7 +109,7 @@ require ( golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.39.0 // indirect - golang.org/x/oauth2 v0.26.0 // indirect + golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect golang.org/x/term v0.31.0 // indirect @@ -118,8 +118,8 @@ require ( golang.org/x/tools v0.31.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect @@ -129,6 +129,6 @@ require ( k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0 // indirect - sigs.k8s.io/controller-tools v0.16.3 // indirect + sigs.k8s.io/controller-tools v0.17.3 // indirect sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect ) diff --git a/go.sum b/go.sum index a47a9f905..508f9ea50 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.20.0 h1:OunBvVCfvpWlt4dN7zg3FM6TDkzOePe1+foGJ9AXeeI= -cel.dev/expr v0.20.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= +cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= 
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= @@ -20,9 +20,9 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 h1:Om6kYQYDUk5wWbT0t0q6pvyM49i9XZAv9dDrkDA7gjk= -github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySeApCX4GeOjPl9qhRF3QuIZq+Q= +github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -35,12 +35,12 @@ github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8k github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= -github.com/evanphx/json-patch v5.7.0+incompatible h1:vgGkfT/9f8zE6tvSCe74nfpAVDQ2tG6yudJd8LBksgI= -github.com/evanphx/json-patch v5.7.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM= -github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= -github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -69,8 +69,8 @@ github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7a github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/gobuffalo/flect v1.0.2 
h1:eqjPGSo2WmjgY2XlpGwo2NXgL3RucAKo4k4qQMNA5sA= -github.com/gobuffalo/flect v1.0.2/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= +github.com/gobuffalo/flect v1.0.3 h1:xeWBM2nui+qnVvNM4S3foBhCAL2XgPU+a7FdpelbTq4= +github.com/gobuffalo/flect v1.0.3/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= github.com/goccy/go-yaml v1.11.3 h1:B3W9IdWbvrUu2OYQGwvU1nZtvMQJPBKgBUuweJjLj6I= github.com/goccy/go-yaml v1.11.3/go.mod h1:wKnAMd44+9JAAnGQpWVEgBzGt3YuTaQ4uXoHvE4m7WU= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -169,10 +169,10 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stoewer/go-strcase v1.3.0 h1:g0eASXYtp+yvN9fK8sH94oCIk0fau9uV1/ZdJ0AVEzs= github.com/stoewer/go-strcase v1.3.0/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -234,8 +234,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= -golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -271,12 +271,12 @@ golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSm golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= -google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a h1:nwKuGPlUAt+aR+pcrkfFRrTU1BVrSmYyYMxYbUIVHr0= -google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a/go.mod h1:3kWAYMk1I75K4vykHtKt2ycnOgpA6974V7bREqbsenU= 
-google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 h1:h6p3mQqrmT1XkHVTfzLdNz1u7IhINeZkz67/xTbOuWs= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= -google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -319,15 +319,15 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0 h1:CPT0ExVicCzcp sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= sigs.k8s.io/controller-runtime v0.20.4 h1:X3c+Odnxz+iPTRobG4tp092+CvBU9UK0t/bRf+n0DGU= sigs.k8s.io/controller-runtime v0.20.4/go.mod h1:xg2XB0K5ShQzAgsoujxuKN4LNXR2LfwwHsPj7Iaw+XY= -sigs.k8s.io/controller-tools v0.16.3 h1:z48C5/d4jCVQQvtiSBL5MYyZ3EO2eFIOXrIKMgHVhFY= -sigs.k8s.io/controller-tools v0.16.3/go.mod h1:AEj6k+w1kYpLZv2einOH3mj52ips4W/6FUjnB5tkJGs= -sigs.k8s.io/gateway-api v1.2.1 h1:fZZ/+RyRb+Y5tGkwxFKuYuSRQHu9dZtbjenblleOLHM= -sigs.k8s.io/gateway-api v1.2.1/go.mod h1:EpNfEXNjiYfUJypf0eZ0P5iXA9ekSGWaS1WgPaM42X0= +sigs.k8s.io/controller-tools v0.17.3 h1:lwFPLicpBKLgIepah+c8ikRBubFW5kOQyT88r3EwfNw= +sigs.k8s.io/controller-tools v0.17.3/go.mod h1:1ii+oXcYZkxcBXzwv3YZBlzjt1fvkrCGjVF73blosJI= +sigs.k8s.io/gateway-api v1.3.0 h1:q6okN+/UKDATola4JY7zXzx40WO4VISk7i9DIfOvr9M= +sigs.k8s.io/gateway-api v1.3.0/go.mod h1:d8NV8nJbaRbEKem+5IuxkL8gJGOZ+FJ+NvOIltV8gDk= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016 h1:kXv6kKdoEtedwuqMmkqhbkgvYKeycVbC8+iPCP9j5kQ= sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= -sigs.k8s.io/structured-merge-diff/v4 v4.6.0 h1:IUA9nvMmnKWcj5jl84xn+T5MnlZKThmUW1TdblaLVAc= -sigs.k8s.io/structured-merge-diff/v4 v4.6.0/go.mod h1:dDy58f92j70zLsuZVuUX5Wp9vtxXpaZnkPGWeqDfCps= +sigs.k8s.io/structured-merge-diff/v4 v4.7.0 h1:qPeWmscJcXP0snki5IYF79Z8xrl8ETFxgMd7wez1XkI= +sigs.k8s.io/structured-merge-diff/v4 v4.7.0/go.mod h1:dDy58f92j70zLsuZVuUX5Wp9vtxXpaZnkPGWeqDfCps= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index a1151d7d3..1df05ce5d 100644 --- 
a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -18,100 +18,50 @@ package handlers import ( "context" - "encoding/json" - "fmt" "strconv" "time" + configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "google.golang.org/protobuf/types/known/structpb" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling. -func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { - logger := log.FromContext(ctx) - - var requestBodyBytes []byte - requestBodyMap := reqCtx.Request.Body - // Resolve target models. - model, ok := requestBodyMap["model"].(string) - if !ok { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} - } - prompt, ok := requestBodyMap["prompt"].(string) - if !ok { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} - } - - modelName := model +func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error { + reqCtx.RequestReceivedTimestamp = time.Now() - // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. - // This might be a security risk in the future where adapters not registered in the InferenceModel - // are able to be requested by using their distinct name. - modelObj := s.datastore.ModelGet(model) - if modelObj == nil { - return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} - } - if len(modelObj.Spec.TargetModels) > 0 { - modelName = RandomWeightedDraw(logger, modelObj, 0) - if modelName == "" { - return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} + // an EoS in the request headers means this request has no body or trailers. + if req.RequestHeaders.EndOfStream { + // We will route this request to a random pod as this is assumed to just be a GET + // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526 + // The above PR will address endpoint admission, but currently any request without a body will be + // routed to a random upstream pod. 
+ pod := s.director.GetRandomPod() + if pod == nil { + return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"} } + pool, err := s.datastore.PoolGet() + if err != nil { + return err + } + reqCtx.TargetEndpoint = pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) + reqCtx.RequestSize = 0 + reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) + return nil } - llmReq := &schedulingtypes.LLMRequest{ - Model: model, - ResolvedTargetModel: modelName, - Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, - Prompt: prompt, - Headers: reqCtx.Request.Headers, - } - logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) - - var err error - // Update target models in the body. - if llmReq.Model != llmReq.ResolvedTargetModel { - requestBodyMap["model"] = llmReq.ResolvedTargetModel - } - - requestBodyBytes, err = json.Marshal(requestBodyMap) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body") - return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} - } - - res, err := s.scheduler.Schedule(ctx, llmReq) - if err != nil { - return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} - } - targetPod := res.TargetPod.GetPod() - // Insert target endpoint to instruct Envoy to route requests to the specified target pod. - // Attach the port number - pool, err := s.datastore.PoolGet() - if err != nil { - return reqCtx, err + for _, header := range req.RequestHeaders.Headers.Headers { + if header.RawValue != nil { + reqCtx.Request.Headers[header.Key] = string(header.RawValue) + } else { + reqCtx.Request.Headers[header.Key] = header.Value + } } - endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - - logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) - - reqCtx.Model = llmReq.Model - reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel - reqCtx.RequestSize = len(requestBodyBytes) - reqCtx.TargetPod = targetPod.NamespacedName.String() - reqCtx.TargetEndpoint = endpoint - - s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders) + return nil +} - reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ - // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header - // and as an unstructure ext-proc response metadata key/value pair. This enables different integration - // options for gateway providers. 
+func (s *StreamingServer) generateRequestBodyResponse(requestBodyBytes []byte) *extProcPb.ProcessingResponse {
+	return &extProcPb.ProcessingResponse{
 		Response: &extProcPb.ProcessingResponse_RequestBody{
 			RequestBody: &extProcPb.BodyResponse{
 				Response: &extProcPb.CommonResponse{
@@ -127,37 +77,82 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
 			},
 		},
 	}
-	return reqCtx, nil
 }
 
-func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
-	reqCtx.RequestReceivedTimestamp = time.Now()
+func (s *StreamingServer) generateRequestHeaderResponse(reqCtx *RequestContext) *extProcPb.ProcessingResponse {
+	// The Endpoint Picker supports two approaches to communicating the target endpoint: as a request header
+	// and as an unstructured ext-proc response metadata key/value pair. This enables different integration
+	// options for gateway providers.
+	return &extProcPb.ProcessingResponse{
+		Response: &extProcPb.ProcessingResponse_RequestHeaders{
+			RequestHeaders: &extProcPb.HeadersResponse{
+				Response: &extProcPb.CommonResponse{
+					ClearRouteCache: true,
+					HeaderMutation: &extProcPb.HeaderMutation{
+						SetHeaders: s.generateHeaders(reqCtx),
+					},
+				},
+			},
+		},
+		DynamicMetadata: s.generateMetadata(reqCtx.TargetEndpoint),
+	}
+}
 
-	// an EoS in the request headers means this request has no body or trailers.
-	if req.RequestHeaders.EndOfStream {
-		// We will route this request to a random pod as this is assumed to just be a GET
-		// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
-		// The above PR will address endpoint admission, but currently any request without a body will be
-		// routed to a random upstream pod.
-		pod := GetRandomPod(s.datastore)
-		if pod == nil {
-			return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
-		}
-		pool, err := s.datastore.PoolGet()
-		if err != nil {
-			return err
-		}
-		endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
-		s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil)
-		return nil
+func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption {
+	// can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
+ headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: s.destinationEndpointHintKey, + RawValue: []byte(reqCtx.TargetEndpoint), + }, + }, + } + if reqCtx.RequestSize > 0 { + // We need to update the content length header if the body is mutated, see Envoy doc: + // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(reqCtx.RequestSize)), + }, + }) } - for _, header := range req.RequestHeaders.Headers.Headers { - if header.RawValue != nil { - reqCtx.Request.Headers[header.Key] = string(header.RawValue) - } else { - reqCtx.Request.Headers[header.Key] = header.Value + // include all headers + for key, value := range reqCtx.Request.Headers { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } + return headers +} + +func (s *StreamingServer) generateMetadata(endpoint string) *structpb.Struct { + targetEndpointValue := &structpb.Struct{ + Fields: map[string]*structpb.Value{ + s.destinationEndpointHintKey: { + Kind: &structpb.Value_StringValue{ + StringValue: endpoint, + }, + }, + }, + } + dynamicMetadata := targetEndpointValue + if s.destinationEndpointHintMetadataNamespace != "" { + // If a namespace is defined, wrap the selected endpoint with that. + dynamicMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + s.destinationEndpointHintMetadataNamespace: { + Kind: &structpb.Value_StructValue{ + StructValue: targetEndpointValue, + }, + }, + }, } } - return nil + return dynamicMetadata } diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go index 8675b64cb..f4f2eb136 100644 --- a/pkg/epp/handlers/request_test.go +++ b/pkg/epp/handlers/request_test.go @@ -16,169 +16,7 @@ limitations under the License. package handlers -import ( - "context" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client/fake" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" - testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" -) - const ( DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb" // default for --destinationEndpointHintMetadataNamespace DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint" // default for --destinationEndpointHintKey ) - -func TestHandleRequestBody(t *testing.T) { - ctx := logutil.NewTestLoggerIntoContext(context.Background()) - - // Setup datastore - tsModel := "food-review" - modelWithTarget := "food-review-0" - model1 := testutil.MakeInferenceModel("model1"). - CreationTimestamp(metav1.Unix(1000, 0)). - ModelName(tsModel).ObjRef() - model2 := testutil.MakeInferenceModel("model2"). - CreationTimestamp(metav1.Unix(1000, 0)). 
- ModelName(modelWithTarget).ObjRef() - pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := datastore.NewDatastore(t.Context(), pmf) - ds.ModelSetIfOlder(model1) - ds.ModelSetIfOlder(model2) - - pool := &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - TargetPortNumber: int32(8000), - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ - "some-key": "some-val", - }, - }, - } - pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}} - scheme := runtime.NewScheme() - _ = clientgoscheme.AddToScheme(scheme) - fakeClient := fake.NewClientBuilder(). - WithScheme(scheme). - Build() - if err := ds.PoolSet(ctx, fakeClient, pool); err != nil { - t.Error(err, "Error while setting inference pool") - } - ds.PodUpdateOrAddIfNotExist(pod) - - tests := []struct { - name string - reqBodyMap map[string]interface{} - wantErrCode string - wantReqCtx *RequestContext - wantRespBody map[string]interface{} - }{ - { - name: "successful request", - reqBodyMap: map[string]interface{}{ - "model": tsModel, - "prompt": "test prompt", - }, - wantReqCtx: &RequestContext{ - Model: tsModel, - ResolvedTargetModel: tsModel, - TargetPod: "/pod1", - TargetEndpoint: "address-1:8000", - }, - wantRespBody: map[string]interface{}{ - "model": tsModel, - "prompt": "test prompt", - }, - }, - { - name: "successful request with target model", - reqBodyMap: map[string]interface{}{ - "model": modelWithTarget, - "prompt": "test prompt", - }, - wantReqCtx: &RequestContext{ - Model: modelWithTarget, - ResolvedTargetModel: modelWithTarget, - TargetPod: "/pod1", - TargetEndpoint: "address-1:8000", - }, - wantRespBody: map[string]interface{}{ - "model": modelWithTarget, - "prompt": "test prompt", - }, - }, - { - name: "no model defined, expect err", - wantErrCode: errutil.BadRequest, - }, - { - name: "invalid model defined, expect err", - reqBodyMap: map[string]interface{}{ - "model": "non-existent-model", - "prompt": "test prompt", - }, - wantErrCode: errutil.BadConfiguration, - }, - { - name: "invalid target defined, expect err", - reqBodyMap: map[string]interface{}{ - "model": "food-review-1", - "prompt": "test prompt", - }, - wantErrCode: errutil.BadConfiguration, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - server := NewStreamingServer(scheduling.NewScheduler(ds), DefaultDestinationEndpointHintMetadataNamespace, DefaultDestinationEndpointHintKey, ds) - reqCtx := &RequestContext{ - Request: &Request{ - Body: test.reqBodyMap, - }, - } - reqCtx, err := server.HandleRequestBody(ctx, reqCtx) - - if test.wantErrCode != "" { - if err == nil { - t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode) - } - if !strings.Contains(err.Error(), test.wantErrCode) { - t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode) - } - return - } - - if err != nil { - t.Fatalf("HandleRequestBody returned unexpected error: %v", err) - } - - if test.wantReqCtx != nil { - if diff := cmp.Diff(test.wantReqCtx.Model, reqCtx.Model); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff) - } - if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, reqCtx.ResolvedTargetModel); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff) - } - if diff := 
cmp.Diff(test.wantReqCtx.TargetPod, reqCtx.TargetPod); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff) - } - if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, reqCtx.TargetEndpoint); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff) - } - } - }) - } -} diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 77f5b163a..ae9c6f4be 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -20,8 +20,6 @@ import ( "context" "encoding/json" "io" - "math/rand" - "strconv" "strings" "time" @@ -31,46 +29,50 @@ import ( "github.com/go-logr/logr" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" ) -func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *StreamingServer { +func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore Datastore, director Director) *StreamingServer { return &StreamingServer{ - scheduler: scheduler, destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace, destinationEndpointHintKey: destinationEndpointHintKey, + director: director, datastore: datastore, } } +type Director interface { + HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) + GetRandomPod() *backend.Pod +} + +type Datastore interface { + PoolGet() (*v1alpha2.InferencePool, error) +} + // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type StreamingServer struct { - scheduler Scheduler // The key of the header to specify the target pod address. This value needs to match Envoy // configuration. destinationEndpointHintKey string // The key acting as the outer namespace struct in the metadata extproc response to communicate // back the picked endpoints. destinationEndpointHintMetadataNamespace string - datastore datastore.Datastore -} - -type Scheduler interface { - Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) - ScheduleWithContext(ctx *schedulingtypes.SchedulingContext, req *schedulingtypes.LLMRequest) (*schedulingtypes.Result, error) + datastore Datastore + director Director } // RequestContext stores context information during the life time of an HTTP request. +// TODO: The requestContext is gathering a ton of fields. A future refactor needs to tease these fields apart. +// Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request. +// We should split these apart as this monolithic object exposes too much data to too many layers. 
type RequestContext struct { TargetPod string TargetEndpoint string @@ -193,13 +195,24 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) // Body stream complete. Allocate empty slice for response to use. body = []byte{} - reqCtx, err = s.HandleRequestBody(ctx, reqCtx) + reqCtx, err = s.director.HandleRequest(ctx, reqCtx) if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error handling body") - } else { - metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) - metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) + logger.V(logutil.DEFAULT).Error(err, "Error handling request") + break } + + // Populate the ExtProc protocol responses for the request body. + requestBodyBytes, err := json.Marshal(reqCtx.Request.Body) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error marshalling request body") + break + } + reqCtx.RequestSize = len(requestBodyBytes) + reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) + reqCtx.reqBodyResp = s.generateRequestBodyResponse(requestBodyBytes) + + metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) + metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) } case *extProcPb.ProcessingRequest_RequestTrailers: // This is currently unused. @@ -376,114 +389,6 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces return nil } -func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) { - headers := []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: s.destinationEndpointHintKey, - RawValue: []byte(endpoint), - }, - }, - } - if requestBodyLength > 0 { - // We need to update the content length header if the body is mutated, see Envoy doc: - // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto - headers = append(headers, &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte(strconv.Itoa(requestBodyLength)), - }, - }) - } - // Add headers added by filters/scorers - for key, value := range mutatedHeaders { - headers = append(headers, &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: key, - RawValue: []byte(value), - }, - }) - } - - targetEndpointValue := &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintKey: { - Kind: &structpb.Value_StringValue{ - StringValue: endpoint, - }, - }, - }, - } - dynamicMetadata := targetEndpointValue - if s.destinationEndpointHintMetadataNamespace != "" { - // If a namespace is defined, wrap the selected endpoint with that. 
- dynamicMetadata = &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintMetadataNamespace: { - Kind: &structpb.Value_StructValue{ - StructValue: targetEndpointValue, - }, - }, - }, - } - } - - reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - ClearRouteCache: true, - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: headers, - }, - }, - }, - }, - DynamicMetadata: dynamicMetadata, - } -} - -func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string { - // TODO: after we are down to 1 server implementation, make these methods a part of the struct - // and handle random seeding on the struct. - source := rand.NewSource(rand.Int63()) - if seed > 0 { - source = rand.NewSource(seed) - } - r := rand.New(source) - - // all the weight values are nil, then we should return random model name - if model.Spec.TargetModels[0].Weight == nil { - index := r.Int31n(int32(len(model.Spec.TargetModels))) - return model.Spec.TargetModels[index].Name - } - - var weights int32 - for _, model := range model.Spec.TargetModels { - weights += *model.Weight - } - logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights) - randomVal := r.Int31n(weights) - // TODO: optimize this without using loop - for _, model := range model.Spec.TargetModels { - if randomVal < *model.Weight { - return model.Name - } - randomVal -= *model.Weight - } - return "" -} - -func GetRandomPod(ds datastore.Datastore) *backend.Pod { - pods := ds.PodGetAll() - if len(pods) == 0 { - return nil - } - number := rand.Intn(len(pods)) - pod := pods[number] - return pod.GetPod() -} - func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) { var resp *extProcPb.ProcessingResponse diff --git a/pkg/epp/handlers/server_test.go b/pkg/epp/handlers/server_test.go deleted file mode 100644 index 23d2b68fa..000000000 --- a/pkg/epp/handlers/server_test.go +++ /dev/null @@ -1,186 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package handlers - -import ( - "testing" - "time" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -func TestRandomWeightedDraw(t *testing.T) { - logger := logutil.NewTestLogger() - tests := []struct { - name string - model *v1alpha2.InferenceModel - want string - }{ - { - name: "'random' distribution", - model: &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - TargetModels: []v1alpha2.TargetModel{ - { - Name: "canary", - Weight: pointer(50), - }, - { - Name: "v1", - Weight: pointer(50), - }, - }, - }, - }, - want: "canary", - }, - { - name: "'random' distribution", - model: &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - TargetModels: []v1alpha2.TargetModel{ - { - Name: "canary", - Weight: pointer(25), - }, - { - Name: "v1.1", - Weight: pointer(55), - }, - { - Name: "v1", - Weight: pointer(50), - }, - }, - }, - }, - want: "v1", - }, - { - name: "'random' distribution", - model: &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - TargetModels: []v1alpha2.TargetModel{ - { - Name: "canary", - Weight: pointer(20), - }, - { - Name: "v1.1", - Weight: pointer(20), - }, - { - Name: "v1", - Weight: pointer(10), - }, - }, - }, - }, - want: "v1.1", - }, - { - name: "weighted distribution with weight unset", - model: &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - TargetModels: []v1alpha2.TargetModel{ - { - Name: "canary", - }, - { - Name: "v1.1", - }, - { - Name: "v1", - }, - }, - }, - }, - want: "canary", - }, - } - var seedVal int64 = 420 - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for range 10000 { - model := RandomWeightedDraw(logger, test.model, seedVal) - if model != test.want { - t.Errorf("Model returned: %v != %v", model, test.want) - break - } - } - }) - } -} - -func TestGetRandomPod(t *testing.T) { - tests := []struct { - name string - storePods []*corev1.Pod - expectNil bool - }{ - { - name: "No pods available", - storePods: []*corev1.Pod{}, - expectNil: true, - }, - { - name: "Single pod available", - storePods: []*corev1.Pod{ - {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, - }, - expectNil: false, - }, - { - name: "Multiple pods available", - storePods: []*corev1.Pod{ - {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, - {ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}, - {ObjectMeta: metav1.ObjectMeta{Name: "pod3"}}, - }, - expectNil: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pmf := metrics.NewPodMetricsFactory(&metrics.FakePodMetricsClient{}, time.Millisecond) - ds := datastore.NewDatastore(t.Context(), pmf) - for _, pod := range test.storePods { - ds.PodUpdateOrAddIfNotExist(pod) - } - - gotPod := GetRandomPod(ds) - - if test.expectNil && gotPod != nil { - t.Errorf("expected nil pod, got: %v", gotPod) - } - if !test.expectNil && gotPod == nil { - t.Errorf("expected non-nil pod, got nil") - } - }) - } -} - -func pointer(v int32) *int32 { - return &v -} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go new file mode 100644 index 000000000..cafcc80b0 --- /dev/null +++ b/pkg/epp/requestcontrol/director.go @@ -0,0 +1,184 @@ +/* +Copyright 2025 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +import ( + "context" + "fmt" + "math/rand" + "strconv" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type Scheduler interface { + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) +} + +type Director struct { + datastore datastore.Datastore + scheduler Scheduler +} + +func NewDirector(datastore datastore.Datastore, scheduler Scheduler) *Director { + return &Director{ + datastore: datastore, + scheduler: scheduler, + } +} + +// HandleRequest always returns the requestContext even in the error case, as the request context is used in error handling. +func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + logger := log.FromContext(ctx) + + // Resolve target models. + var ok bool + requestBodyMap := reqCtx.Request.Body + reqCtx.Model, ok = requestBodyMap["model"].(string) + if !ok { + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} + } + prompt, ok := requestBodyMap["prompt"].(string) + if !ok { + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} + } + + // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. + // This might be a security risk in the future where adapters not registered in the InferenceModel + // are able to be requested by using their distinct name. 
+ modelObj := d.datastore.ModelGet(reqCtx.Model) + if modelObj == nil { + return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", reqCtx.Model)} + } + + reqCtx.ResolvedTargetModel = reqCtx.Model + if len(modelObj.Spec.TargetModels) > 0 { + reqCtx.ResolvedTargetModel = RandomWeightedDraw(logger, modelObj, 0) + if reqCtx.ResolvedTargetModel == "" { + return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} + } + } + + llmReq := &schedulingtypes.LLMRequest{ + Model: reqCtx.Model, + ResolvedTargetModel: reqCtx.ResolvedTargetModel, + Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, + Prompt: prompt, + Headers: reqCtx.Request.Headers, + } + logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) + results, err := d.Dispatch(ctx, llmReq) + if err != nil { + return reqCtx, err + } + + // Insert target endpoint to instruct Envoy to route requests to the specified target pod. + // Attach the port number + reqCtx, err = d.PostDispatch(ctx, reqCtx, results) + if err != nil { + return reqCtx, err + } + + return reqCtx, nil +} + +// Dispatch runs one or many scheduling cycles. +func (d *Director) Dispatch(ctx context.Context, llmReq *schedulingtypes.LLMRequest) ([]*schedulingtypes.Result, error) { + var err error + res, err := d.scheduler.Schedule(ctx, llmReq) + if err != nil { + return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} + } + + return []*schedulingtypes.Result{res}, nil +} + +func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestContext, results []*schedulingtypes.Result) (*handlers.RequestContext, error) { + logger := log.FromContext(ctx) + // currently only get a single result. Will refactor to pluggably implement the PostSchedule + if len(results) == 0 { + return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"} + } + targetPod := results[0].TargetPod.GetPod() + + pool, err := d.datastore.PoolGet() + if err != nil { + return reqCtx, err + } + + endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) + logger.V(logutil.DEFAULT).Info("Request handled", + "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod) + + // Update target models in the body. + if reqCtx.Model != reqCtx.ResolvedTargetModel { + reqCtx.Request.Body["model"] = reqCtx.ResolvedTargetModel + } + reqCtx.TargetPod = targetPod.NamespacedName.String() + reqCtx.TargetEndpoint = endpoint + + return reqCtx, nil +} + +func (d *Director) GetRandomPod() *backend.Pod { + pods := d.datastore.PodGetAll() + if len(pods) == 0 { + return nil + } + number := rand.Intn(len(pods)) + pod := pods[number] + return pod.GetPod() +} + +func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string { + // TODO: after we are down to 1 server implementation, make these methods a part of the struct + // and handle random seeding on the struct. 
+ source := rand.NewSource(rand.Int63()) + if seed > 0 { + source = rand.NewSource(seed) + } + r := rand.New(source) + + // all the weight values are nil, then we should return random model name + if model.Spec.TargetModels[0].Weight == nil { + index := r.Int31n(int32(len(model.Spec.TargetModels))) + return model.Spec.TargetModels[index].Name + } + + var weights int32 + for _, model := range model.Spec.TargetModels { + weights += *model.Weight + } + logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights) + randomVal := r.Int31n(weights) + // TODO: optimize this without using loop + for _, model := range model.Spec.TargetModels { + if randomVal < *model.Weight { + return model.Name + } + randomVal -= *model.Weight + } + return "" +} diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go new file mode 100644 index 000000000..05dc1b3b8 --- /dev/null +++ b/pkg/epp/requestcontrol/director_test.go @@ -0,0 +1,339 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" +) + +func TestHandleRequest(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + // Setup datastore + tsModel := "food-review" + modelWithTarget := "food-review-0" + model1 := testutil.MakeInferenceModel("model1"). + CreationTimestamp(metav1.Unix(1000, 0)). + ModelName(tsModel).ObjRef() + model2 := testutil.MakeInferenceModel("model2"). + CreationTimestamp(metav1.Unix(1000, 0)). 
+ ModelName(modelWithTarget).ObjRef() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := datastore.NewDatastore(t.Context(), pmf) + ds.ModelSetIfOlder(model1) + ds.ModelSetIfOlder(model2) + + pool := &v1alpha2.InferencePool{ + Spec: v1alpha2.InferencePoolSpec{ + TargetPortNumber: int32(8000), + Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ + "some-key": "some-val", + }, + }, + } + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}} + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() + if err := ds.PoolSet(ctx, fakeClient, pool); err != nil { + t.Error(err, "Error while setting inference pool") + } + ds.PodUpdateOrAddIfNotExist(pod) + + tests := []struct { + name string + reqBodyMap map[string]interface{} + wantErrCode string + wantReqCtx *handlers.RequestContext + wantRespBody map[string]interface{} + }{ + { + name: "successful request", + reqBodyMap: map[string]interface{}{ + "model": tsModel, + "prompt": "test prompt", + }, + wantReqCtx: &handlers.RequestContext{ + Model: tsModel, + ResolvedTargetModel: tsModel, + TargetPod: "/pod1", + TargetEndpoint: "address-1:8000", + }, + wantRespBody: map[string]interface{}{ + "model": tsModel, + "prompt": "test prompt", + }, + }, + { + name: "successful request with target model", + reqBodyMap: map[string]interface{}{ + "model": modelWithTarget, + "prompt": "test prompt", + }, + wantReqCtx: &handlers.RequestContext{ + Model: modelWithTarget, + ResolvedTargetModel: modelWithTarget, + TargetPod: "/pod1", + TargetEndpoint: "address-1:8000", + }, + wantRespBody: map[string]interface{}{ + "model": modelWithTarget, + "prompt": "test prompt", + }, + }, + { + name: "no model defined, expect err", + wantErrCode: errutil.BadRequest, + }, + { + name: "invalid model defined, expect err", + reqBodyMap: map[string]interface{}{ + "model": "non-existent-model", + "prompt": "test prompt", + }, + wantErrCode: errutil.BadConfiguration, + }, + { + name: "invalid target defined, expect err", + reqBodyMap: map[string]interface{}{ + "model": "food-review-1", + "prompt": "test prompt", + }, + wantErrCode: errutil.BadConfiguration, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := NewDirector(ds, scheduling.NewScheduler(ds)) + reqCtx := &handlers.RequestContext{ + Request: &handlers.Request{ + Body: test.reqBodyMap, + }, + } + reqCtx, err := server.HandleRequest(ctx, reqCtx) + + if test.wantErrCode != "" { + if err == nil { + t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode) + } + if !strings.Contains(err.Error(), test.wantErrCode) { + t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode) + } + return + } + + if err != nil { + t.Fatalf("HandleRequestBody returned unexpected error: %v", err) + } + + if test.wantReqCtx != nil { + if diff := cmp.Diff(test.wantReqCtx.Model, reqCtx.Model); diff != "" { + t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff) + } + if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, reqCtx.ResolvedTargetModel); diff != "" { + t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff) + } + if diff := cmp.Diff(test.wantReqCtx.TargetPod, reqCtx.TargetPod); diff != "" { + 
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff) + } + if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, reqCtx.TargetEndpoint); diff != "" { + t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff) + } + } + }) + } +} + +func TestRandomWeightedDraw(t *testing.T) { + logger := logutil.NewTestLogger() + tests := []struct { + name string + model *v1alpha2.InferenceModel + want string + }{ + { + name: "'random' distribution", + model: &v1alpha2.InferenceModel{ + Spec: v1alpha2.InferenceModelSpec{ + TargetModels: []v1alpha2.TargetModel{ + { + Name: "canary", + Weight: pointer(50), + }, + { + Name: "v1", + Weight: pointer(50), + }, + }, + }, + }, + want: "canary", + }, + { + name: "'random' distribution", + model: &v1alpha2.InferenceModel{ + Spec: v1alpha2.InferenceModelSpec{ + TargetModels: []v1alpha2.TargetModel{ + { + Name: "canary", + Weight: pointer(25), + }, + { + Name: "v1.1", + Weight: pointer(55), + }, + { + Name: "v1", + Weight: pointer(50), + }, + }, + }, + }, + want: "v1", + }, + { + name: "'random' distribution", + model: &v1alpha2.InferenceModel{ + Spec: v1alpha2.InferenceModelSpec{ + TargetModels: []v1alpha2.TargetModel{ + { + Name: "canary", + Weight: pointer(20), + }, + { + Name: "v1.1", + Weight: pointer(20), + }, + { + Name: "v1", + Weight: pointer(10), + }, + }, + }, + }, + want: "v1.1", + }, + { + name: "weighted distribution with weight unset", + model: &v1alpha2.InferenceModel{ + Spec: v1alpha2.InferenceModelSpec{ + TargetModels: []v1alpha2.TargetModel{ + { + Name: "canary", + }, + { + Name: "v1.1", + }, + { + Name: "v1", + }, + }, + }, + }, + want: "canary", + }, + } + var seedVal int64 = 420 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for range 10000 { + model := RandomWeightedDraw(logger, test.model, seedVal) + if model != test.want { + t.Errorf("Model returned: %v != %v", model, test.want) + break + } + } + }) + } +} + +func TestGetRandomPod(t *testing.T) { + tests := []struct { + name string + storePods []*corev1.Pod + expectNil bool + }{ + { + name: "No pods available", + storePods: []*corev1.Pod{}, + expectNil: true, + }, + { + name: "Single pod available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + }, + expectNil: false, + }, + { + name: "Multiple pods available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod3"}}, + }, + expectNil: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pmf := metrics.NewPodMetricsFactory(&metrics.FakePodMetricsClient{}, time.Millisecond) + ds := datastore.NewDatastore(t.Context(), pmf) + for _, pod := range test.storePods { + ds.PodUpdateOrAddIfNotExist(pod) + } + d := &Director{datastore: ds} + gotPod := d.GetRandomPod() + + if test.expectNil && gotPod != nil { + t.Errorf("expected nil pod, got: %v", gotPod) + } + if !test.expectNil && gotPod == nil { + t.Errorf("expected non-nil pod, got nil") + } + }) + } +} + +func pointer(v int32) *int32 { + return &v +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 1e8ea330f..940299feb 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -128,7 +128,6 @@ func (s *Scheduler) ScheduleWithContext(sCtx *types.SchedulingContext, req *type s.runPostSchedulePlugins(sCtx, result) - result.MutatedHeaders = 
sCtx.MutatedHeaders return result, nil } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index da2874c06..b44c7ac2e 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -109,7 +109,6 @@ func TestSchedule(t *testing.T) { }, }, }, - MutatedHeaders: make(map[string]string), }, }, { @@ -173,7 +172,6 @@ func TestSchedule(t *testing.T) { }, }, }, - MutatedHeaders: make(map[string]string), }, }, { @@ -244,27 +242,18 @@ func TestSchedule(t *testing.T) { func TestSchedulePlugins(t *testing.T) { tp1 := &TestPlugin{ - NameRes: "test1", - ScoreRes: 0.3, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, - ReceivedRequestHeaders: make(map[string]string), + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, } tp2 := &TestPlugin{ - NameRes: "test2", - ScoreRes: 0.8, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, - ReceivedRequestHeaders: make(map[string]string), + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, } tp_filterAll := &TestPlugin{ - NameRes: "filter all", - FilterRes: []k8stypes.NamespacedName{}, - ReceivedRequestHeaders: make(map[string]string), - } - tp_headers := &TestPlugin{ - NameRes: "headers", - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, - ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, - ReceivedRequestHeaders: make(map[string]string), + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, } pickerPlugin := &TestPlugin{ NameRes: "picker", @@ -272,13 +261,11 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - config SchedulerConfig - input []*backendmetrics.FakePodMetrics - requestHeaders map[string]string - wantTargetPod k8stypes.NamespacedName - wantMutatedHeaders map[string]string - targetPodScore float64 + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + wantTargetPod k8stypes.NamespacedName + targetPodScore float64 // Number of expected pods to score (after filter) numPodsToScore int err bool @@ -300,11 +287,10 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - wantMutatedHeaders: make(map[string]string), - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, }, { name: "all plugins executed successfully, different scorers weights", @@ -323,11 +309,10 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - wantMutatedHeaders: make(map[string]string), - targetPodScore: 50, - numPodsToScore: 2, - err: false, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 50, + numPodsToScore: 2, + err: false, }, { name: "filter all", @@ -349,33 +334,6 @@ func TestSchedulePlugins(t *testing.T) { numPodsToScore: 0, err: true, // no available pods to server after filter all }, - { - name: "Mutate a header", - config: SchedulerConfig{ - preSchedulePlugins: 
[]plugins.PreSchedule{tp1, tp2}, - filters: []plugins.Filter{tp_headers}, - scorers: map[plugins.Scorer]int{ - tp1: 1, - tp2: 1, - }, - picker: pickerPlugin, - postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, - }, - input: []*backendmetrics.FakePodMetrics{ - {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, - }, - requestHeaders: map[string]string{ - "Content-type": "application/json", - "x-session-id": "qazw-edcr-tgby-nhyu", - }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, // no available pods to server after filter all - }, } for _, test := range tests { @@ -414,10 +372,7 @@ func TestSchedulePlugins(t *testing.T) { wantPod := &types.PodMetrics{ Pod: &backend.Pod{NamespacedName: test.wantTargetPod}, } - wantRes := &types.Result{ - TargetPod: wantPod, - MutatedHeaders: test.wantMutatedHeaders, - } + wantRes := &types.Result{TargetPod: wantPod} if diff := cmp.Diff(wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -482,20 +437,18 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { // TestPlugin is an implementation useful in unit tests. type TestPlugin struct { - NameRes string - ScoreCallCount int - NumOfScoredPods int - ScoreRes float64 - FilterCallCount int - FilterRes []k8stypes.NamespacedName - PreScheduleCallCount int - PostScheduleCallCount int - PickCallCount int - NumOfPickerCandidates int - PickRes k8stypes.NamespacedName - WinnderPodScore float64 - ExtraHeaders map[string]string - ReceivedRequestHeaders map[string]string + NameRes string + ScoreCallCount int + NumOfScoredPods int + ScoreRes float64 + FilterCallCount int + FilterRes []k8stypes.NamespacedName + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + NumOfPickerCandidates int + PickRes k8stypes.NamespacedName + WinnderPodScore float64 } func (tp *TestPlugin) Name() string { return tp.NameRes } @@ -506,12 +459,6 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ - for key, value := range tp.ExtraHeaders { - ctx.MutatedHeaders[key] = value - } - for key, value := range ctx.Req.Headers { - tp.ReceivedRequestHeaders[key] = value - } return findPods(ctx, tp.FilterRes...) } diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index aaefcf5ee..795ef65d2 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -59,10 +59,9 @@ type ScoredPod struct { // SchedulingContext holds contextual information during a scheduling operation. 
type SchedulingContext struct { context.Context - Logger logr.Logger - Req *LLMRequest - PodsSnapshot []Pod - MutatedHeaders map[string]string + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod } func (pm *PodMetrics) String() string { @@ -88,11 +87,10 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, - MutatedHeaders: make(map[string]string), + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, } } @@ -106,6 +104,5 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { // Result captures the scheduler result. type Result struct { - TargetPod Pod - MutatedHeaders map[string]string + TargetPod Pod } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 8abb3b150..4b8620826 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -35,6 +35,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/controller" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" ) // ExtProcServerRunner provides methods to manage an external process server. @@ -48,7 +49,7 @@ type ExtProcServerRunner struct { CertPath string UseStreaming bool RefreshPrometheusMetricsInterval time.Duration - Scheduler handlers.Scheduler + Scheduler requestcontrol.Scheduler // This should only be used in tests. We won't need this once we don't inject metrics in the tests. // TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup @@ -137,7 +138,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { } else { srv = grpc.NewServer() } - extProcServer := handlers.NewStreamingServer(r.Scheduler, r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) + extProcServer := handlers.NewStreamingServer(r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore, requestcontrol.NewDirector(r.Datastore, r.Scheduler)) extProcPb.RegisterExternalProcessorServer( srv, extProcServer, diff --git a/site-src/guides/index.md b/site-src/guides/index.md index be1b972f8..89811263f 100644 --- a/site-src/guides/index.md +++ b/site-src/guides/index.md @@ -76,7 +76,7 @@ This quickstart guide is intended for engineers familiar with k8s and model serv kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/inferencemodel.yaml ``` -### Deploy the InferencePool and Extension +### Deploy the InferencePool and Endpoint Picker Extension ```bash kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/inferencepool-resources.yaml diff --git a/site-src/performance/benchmark/index.md b/site-src/performance/benchmark/index.md index 39457bf66..160cc26fb 100644 --- a/site-src/performance/benchmark/index.md +++ b/site-src/performance/benchmark/index.md @@ -1,45 +1,49 @@ # Benchmark -This user guide shows how to run benchmarks against a vLLM deployment, by using both the Gateway API -inference extension, and a Kubernetes service as the load balancing strategy. 
The -benchmark uses the [Latency Profile Generator](https://github.com/AI-Hypercomputer/inference-benchmark) (LPG) -tool to generate load and collect results. +This user guide shows how to run benchmarks against a vLLM model server deployment by using both Gateway API +Inference Extension and a Kubernetes service as the load balancing strategy. The benchmark uses the +[Latency Profile Generator](https://github.com/AI-Hypercomputer/inference-benchmark) (LPG) tool to generate +load and collect results. ## Prerequisites ### Deploy the inference extension and sample model server -Follow this user guide https://gateway-api-inference-extension.sigs.k8s.io/guides/ to deploy the -sample vLLM application, and the inference extension. +Follow the [getting started guide](https://gateway-api-inference-extension.sigs.k8s.io/guides/#getting-started-with-gateway-api-inference-extension) +to deploy the vLLM model server, CRDs, etc. + +__Note:__ Only the GPU-based model server deployment option is supported for benchmark testing. ### [Optional] Scale the sample vLLM deployment -You will more likely to see the benefits of the inference extension when there are a decent number of replicas to make the optimal routing decision. +You are more likely to see the benefits of the inference extension when there are a decent number of replicas to make the optimal routing decision. ```bash -kubectl scale --replicas=8 -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/vllm/gpu-deployment.yaml +kubectl scale deployment vllm-llama3-8b-instruct --replicas=8 ``` ### Expose the model server via a k8s service -As the baseline, let's also expose the vLLM deployment as a k8s service: +To establish a baseline, expose the vLLM deployment as a k8s service: ```bash -kubectl expose -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/vllm/gpu-deployment.yaml --port=8081 --target-port=8000 --type=LoadBalancer +kubectl expose deployment vllm-llama3-8b-instruct --port=80 --target-port=8000 --type=LoadBalancer ``` ## Run benchmark -The LPG benchmark tool works by sending traffic to the specified target IP and port, and collect results. Follow the steps below to run a single benchmark. You can deploy multiple LPG instances if you want to run benchmarks in parallel against different targets. +The LPG benchmark tool works by sending traffic to the specified target IP and port, and collecting the results. +Follow the steps below to run a single benchmark. Multiple LPG instances can be deployed to run benchmarks in +parallel against different targets. 1. Check out the repo. - + ```bash git clone https://github.com/kubernetes-sigs/gateway-api-inference-extension cd gateway-api-inference-extension ``` -1. Get the target IP. Examples below show how to get the IP of a gateway or a LoadBalancer k8s service. +1. Get the target IP. The examples below show how to get the IP of a gateway or a k8s service. ```bash # Get gateway IP @@ -51,32 +55,43 @@ The LPG benchmark tool works by sending traffic to the specified target IP and p echo $SVC_IP ``` -1. Then update the `` in `./config/manifests/benchmark/benchmark.yaml` to your target IP. Feel free to adjust other parameters such as request_rates as well. For a complete list of LPG configurations, pls refer to the [LPG user guide](https://github.com/AI-Hypercomputer/inference-benchmark?tab=readme-ov-file#configuring-the-benchmark). +1.
Then update the `` in `./config/manifests/benchmark/benchmark.yaml` to the value of `$SVC_IP` or `$GW_IP`. + Feel free to adjust other parameters such as `request_rates` as well. For a complete list of LPG configurations, refer to the + [LPG user guide](https://github.com/AI-Hypercomputer/inference-benchmark?tab=readme-ov-file#configuring-the-benchmark). -1. Start the benchmark tool. `kubectl apply -f ./config/manifests/benchmark/benchmark.yaml` +1. Start the benchmark tool. -1. Wait for benchmark to finish and download the results. Use the `benchmark_id` environment variable -to specify what this benchmark is for. For instance, `inference-extension` or `k8s-svc`. When the LPG tool finishes benchmarking, it will print a log line `LPG_FINISHED`, -the script below will watch for that log line and then start downloading results. + ```bash + kubectl apply -f ./config/manifests/benchmark/benchmark.yaml + ``` + +1. Wait for the benchmark to finish and download the results. Use the `benchmark_id` environment variable to specify what this + benchmark is for. For instance, `inference-extension` or `k8s-svc`. When the LPG tool finishes benchmarking, it will print + a log line `LPG_FINISHED`. The script below will watch for that log line and then start downloading results. ```bash - benchmark_id='my-benchmark' ./tools/benchmark/download-benchmark-results.bash + benchmark_id='k8s-svc' ./tools/benchmark/download-benchmark-results.bash ``` - 1. After the script finishes, you should see benchmark results under `./tools/benchmark/output/default-run/my-benchmark/results/json` folder. Here is a [sample json file](./sample.json). + + After the script finishes, you should see benchmark results under the `./tools/benchmark/output/default-run/k8s-svc/results/json` folder. + Here is a [sample json file](./sample.json). Replace `k8s-svc` with `inference-extension` when running an inference extension benchmark. ### Tips +* When using a `benchmark_id` other than `k8s-svc` or `inference-extension`, the labels in `./tools/benchmark/benchmark.ipynb` must be + updated accordingly to analyze the results. * You can specify `run_id="runX"` environment variable when running the `./download-benchmark-results.bash` script. This is useful when you run benchmarks multiple times to get more statistically meaningful results and group the results accordingly. * Update the `request_rates` that best suit your benchmark environment. ### Advanced Benchmark Configurations -Pls refer to the [LPG user guide](https://github.com/AI-Hypercomputer/inference-benchmark?tab=readme-ov-file#configuring-the-benchmark) for a detailed list of configuration knobs. +Refer to the [LPG user guide](https://github.com/AI-Hypercomputer/inference-benchmark?tab=readme-ov-file#configuring-the-benchmark) for a + detailed list of configuration knobs. ## Analyze the results -This guide shows how to run the jupyter notebook using vscode. +This guide shows how to run the jupyter notebook using vscode after completing the k8s service and inference extension benchmarks. 1. Create a python virtual environment. @@ -92,6 +107,6 @@ This guide shows how to run the jupyter notebook using vscode. ``` 1. Open the notebook `./tools/benchmark/benchmark.ipynb`, and run each cell. At the end you should - see a bar chart like below where **"ie"** represents inference extension.
This chart is generated using this benchmarking tool with 6 vLLM (v1) model servers (H100 80 GB), [llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) and the [ShareGPT dataset](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json). - - ![alt text](example-bar-chart.png) \ No newline at end of file + see a bar chart like below where __"ie"__ represents inference extension. This chart is generated using this benchmarking tool with 6 vLLM (v1) model servers (H100 80 GB), [llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) and the [ShareGPT dataset](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json). + + ![alt text](example-bar-chart.png) diff --git a/test/e2e/epp/e2e_suite_test.go b/test/e2e/epp/e2e_suite_test.go index 01ed639d7..65959234e 100644 --- a/test/e2e/epp/e2e_suite_test.go +++ b/test/e2e/epp/e2e_suite_test.go @@ -62,6 +62,8 @@ const ( modelServerName = "vllm-llama3-8b-instruct" // modelName is the test model name. modelName = "food-review" + // targetModelName is the target model name of the test model server. + targetModelName = modelName + "-1" // envoyName is the name of the envoy proxy test resources. envoyName = "envoy" // envoyPort is the listener port number of the test envoy proxy. diff --git a/test/e2e/epp/e2e_test.go b/test/e2e/epp/e2e_test.go index 7240cebc4..6fce63a39 100644 --- a/test/e2e/epp/e2e_test.go +++ b/test/e2e/epp/e2e_test.go @@ -96,12 +96,8 @@ var _ = ginkgo.Describe("InferencePool", func() { func newInferenceModel(ns string) *v1alpha2.InferenceModel { targets := []v1alpha2.TargetModel{ { - Name: modelName, - Weight: ptr.To(int32(50)), - }, - { - Name: "cad-fabricator", - Weight: ptr.To(int32(50)), + Name: targetModelName, + Weight: ptr.To(int32(100)), }, } return testutils.MakeModelWrapper("inferencemodel-sample", ns). 
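The e2e change above pins a single target model at weight 100. For the weighted case handled by the exported `RandomWeightedDraw` helper added in `pkg/epp/requestcontrol`, here is a minimal, hypothetical usage sketch (not part of this patch): the model names and the 90/10 split are made up, and a fixed seed is passed only to make the draw reproducible, as the unit test above does.

```go
package main

import (
	"fmt"

	"github.com/go-logr/logr"

	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
)

func ptr(v int32) *int32 { return &v }

func main() {
	// Hypothetical InferenceModel with a 90/10 traffic split between two target models.
	model := &v1alpha2.InferenceModel{
		Spec: v1alpha2.InferenceModelSpec{
			TargetModels: []v1alpha2.TargetModel{
				{Name: "food-review-1", Weight: ptr(90)},
				{Name: "food-review-canary", Weight: ptr(10)},
			},
		},
	}
	// A seed > 0 makes the weighted draw deterministic (handy in tests);
	// a non-positive seed falls back to a randomly seeded source.
	picked := requestcontrol.RandomWeightedDraw(logr.Discard(), model, 420)
	fmt.Println("resolved target model:", picked)
}
```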
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 68787da72..244fadf9c 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -155,6 +155,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(76)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, @@ -239,6 +245,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(76)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, @@ -323,6 +335,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(76)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, @@ -456,6 +474,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(76)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, @@ -567,6 +591,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(76)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, @@ -678,6 +708,12 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { RawValue: []byte(strconv.Itoa(74)), }, }, + { + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, }}, }, }, diff --git a/test/testdata/inferencepool-e2e.yaml b/test/testdata/inferencepool-e2e.yaml index 41bc1e1de..0cd89346b 100644 --- a/test/testdata/inferencepool-e2e.yaml +++ b/test/testdata/inferencepool-e2e.yaml @@ -73,13 +73,13 @@ spec: livenessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 readinessProbe: grpc: port: 9003 - service: envoy.service.ext_proc.v3.ExternalProcessor + service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 --- diff --git a/tools/benchmark/benchmark.ipynb b/tools/benchmark/benchmark.ipynb index 993279cb9..ffd4c455e 100644 --- a/tools/benchmark/benchmark.ipynb +++ b/tools/benchmark/benchmark.ipynb @@ -21,7 +21,7 @@ "#@title Configuration. Edit this before running the rest.\n", "\n", "OUTPUT_DIR='output'\n", - "RUN_ID='example-run'\n", + "RUN_ID='default-run'\n", "# Path to the benchmark dir under `gateway-api-inference-extension/benchmark`\n", "BENCHMARK_DIR =\"./\"\n", "# A regex to match the model name, which matches the output file name.\n", @@ -229,7 +229,7 @@ " plot_func(curAx, m)\n", " return fig, axes\n", "\n", - "def plot_bar(labels, groups, metrics=CORE_METRICS, num_plots_per_row=NUM_PLOTS_PER_ROW, interactive=INTERACTIVE_PLOT, annotate=False):\n", + "def plot_bar(labels, groups, metrics=CORE_METRICS, num_plots_per_row=NUM_PLOTS_PER_ROW, interactive=False, annotate=False):\n", " labels = [label.alias for label in labels]\n", " logger.debug(f'Prnting bar chart for {labels}')\n", " logger.debug(f'groups: {groups}')\n",
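The probe changes in the manifests above only swap the gRPC health-check `service` name to `inference-extension` on port 9003. As a quick out-of-cluster sanity check, a small, hypothetical client sketch along these lines (assuming the EPP health port is port-forwarded to localhost and served without TLS) issues the same check the kubelet gRPC probes perform:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	healthPb "google.golang.org/grpc/health/grpc_health_v1"
)

func main() {
	// Assumes something like `kubectl port-forward <epp-pod> 9003:9003` is running.
	conn, err := grpc.NewClient("localhost:9003", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("failed to create gRPC client: %v", err)
	}
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Same service name the liveness/readiness probes in the manifests now use.
	req := &healthPb.HealthCheckRequest{Service: "inference-extension"}
	resp, err := healthPb.NewHealthClient(conn).Check(ctx, req)
	if err != nil {
		log.Fatalf("health check failed: %v", err)
	}
	fmt.Println("status:", resp.GetStatus())
}
```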