diff --git a/cmd/bbr/runner/runner.go b/cmd/bbr/runner/runner.go index 8a4b4291f..8529be79f 100644 --- a/cmd/bbr/runner/runner.go +++ b/cmd/bbr/runner/runner.go @@ -34,6 +34,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/server" + bbrutils "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -102,8 +103,24 @@ func (r *Runner) Run(ctx context.Context) error { return err } + //Initialize PluginRegistry and request/response PluginsChain instances + registry, requestChain, responseChain, err := bbrutils.InitPlugins() + if err != nil { + setupLog.Error(err, "Failed to initialize plugins") + return err + } + + setupLog.Info("BBR started with:", + "registry", registry, + "requestChain", requestChain, + "responseChain", responseChain) + // Setup runner. - serverRunner := runserver.NewDefaultExtProcServerRunner(*grpcPort, *streaming) + serverRunner := runserver.NewDefaultExtProcServerRunner(*grpcPort, + *streaming, + registry, + requestChain, + responseChain) // Register health server. if err := registerHealthServer(mgr, *grpcHealthPort); err != nil { diff --git a/pkg/bbr/README.md b/pkg/bbr/README.md index 80ab38354..87106ba06 100644 --- a/pkg/bbr/README.md +++ b/pkg/bbr/README.md @@ -1,6 +1,6 @@ # Body-Based Routing -This package provides an extension that can be deployed to write the `model` -HTTP body parameter as a header (X-Gateway-Model-Name) so as to enable routing capabilities on the +By default this package provides a pluggable extension that can be deployed to set the `model` +HTTP body parameter as a header (`X-Gateway-Model-Name`) so as to enable routing capabilities on the model name. 
// Keys naming the ordered request and response plugin chains.
// NOTE(review): server_test.go sets an environment variable called
// REQUEST_PLUGINS_CHAIN, so these look like environment-variable names read
// during plugin initialization — confirm against utils.InitPlugins.
const (
	RequestPluginChain  = "REQUEST_PLUGINS_CHAIN"
	ResponsePluginChain = "RESPONSE_PLUGINS_CHAIN"
)

// PluginFactoryFunc is the constructor signature stored in a PluginRegistry.
// Any no-argument function returning a bbrplugins.BBRPlugin is assignable to it.
type PluginFactoryFunc func() bbrplugins.BBRPlugin

// PluginRegistry manages plugin factories (constructors) and plugin instances,
// both addressed by a plugin type key.
type PluginRegistry interface {
	// RegisterFactory stores a constructor under typeKey.
	RegisterFactory(typeKey string, factory PluginFactoryFunc) error
	// RegisterPlugin stores a plugin instance; the factory for the plugin's
	// type is expected to have been registered first.
	RegisterPlugin(plugin bbrplugins.BBRPlugin) error
	// GetFactory returns the constructor registered under typeKey.
	GetFactory(typeKey string) (PluginFactoryFunc, error)
	// GetPlugin returns the plugin instance registered under typeKey.
	GetPlugin(typeKey string) (bbrplugins.BBRPlugin, error)
	// GetFactories returns the registered constructors keyed by type.
	GetFactories() map[string]PluginFactoryFunc
	// GetPlugins returns the registered plugin instances keyed by type.
	GetPlugins() map[string]bbrplugins.BBRPlugin
	// ListPlugins returns the type keys of the registered plugin instances.
	ListPlugins() []string
	// ListFactories returns the type keys of the registered constructors.
	ListFactories() []string
	// CreatePlugin builds a new plugin using the constructor registered under typeKey.
	CreatePlugin(typeKey string) (bbrplugins.BBRPlugin, error)
	// ContainsFactory reports whether a constructor is registered under typeKey.
	ContainsFactory(typeKey string) bool
	// ContainsPlugin reports whether a plugin instance is registered under typeKey.
	ContainsPlugin(typeKey string) bool
	// String returns a printable description of the registry.
	String() string
}

// PluginsChain defines the order in which plugin instances held by a
// PluginRegistry are executed. The chain refers to plugins by type key and
// resolves them through the registry passed to each call.
type PluginsChain interface {
	// AddPlugin appends typeKey to the chain; the plugin should be registered
	// in the registry first.
	AddPlugin(typeKey string, registry PluginRegistry) error
	// AddPluginAtInd places typeKey at position i; only this chain instance is affected.
	AddPluginAtInd(typeKey string, i int, r PluginRegistry) error
	// GetPlugin resolves the plugin at the given chain position through the registry.
	GetPlugin(index int, registry PluginRegistry) (bbrplugins.BBRPlugin, error)
	// Length returns the number of entries in the chain.
	Length() int
	// GetPlugins returns the chain's type keys in execution order.
	GetPlugins() []string
	// Run executes the chain on bodyBytes and returns the merged headers and
	// the (potentially mutated) body.
	Run(bodyBytes []byte, registry PluginRegistry) (map[string]string, []byte, error)
	// String returns a printable description of the chain.
	String() string
}
+*/ + +package framework + +import ( + "fmt" + + bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" +) + +// -------------------- INTERFACES ----------------------------------------------------------------------- +// Interfaces are defined in "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework/interfaces.go" + +// --------------------- PluginRegistry implementation --------------------------------------------------- + +// pluginRegistry implements PluginRegistry +type pluginRegistry struct { + pluginsFactory map[string]PluginFactoryFunc //constructors + plugins map[string]bbrplugins.BBRPlugin // instances +} + +// NewPluginRegistry creates a new instance of pluginRegistry +func NewPluginRegistry() PluginRegistry { + return &pluginRegistry{ + pluginsFactory: make(map[string]PluginFactoryFunc), + plugins: make(map[string]bbrplugins.BBRPlugin), + } +} + +// Register a plugin factory by type key (e.g., "ModelSelector", "MetadataExtractor") +func (r *pluginRegistry) RegisterFactory(typeKey string, factory PluginFactoryFunc) error { + //validate whether already registered + alreadyRegistered := r.ContainsFactory(typeKey) + if alreadyRegistered { + err := fmt.Errorf("factory fot plugin interface type %s is already registered", typeKey) + return err + } + r.pluginsFactory[typeKey] = factory + + return nil +} + +// Register a plugin instance (created through the appropriate factory) +func (r *pluginRegistry) RegisterPlugin(plugin bbrplugins.BBRPlugin) error { + //validate whether this interface is supported + alreadyRegistered := r.ContainsPlugin(plugin.TypedName().Type) + + if alreadyRegistered { + err := fmt.Errorf("plugin implementing interface type %s is already registered", plugin.TypedName().Type) + return err + } + + // validate that the factory for this plugin is registered: always register factory before the plugin + if _, ok := r.pluginsFactory[plugin.TypedName().Type]; !ok { + err := fmt.Errorf("no plugin factory registered for 
plugin interface type %s", plugin.TypedName().Type) + return err + } + r.plugins[plugin.TypedName().Type] = plugin + + return nil +} + +// Retrieves a plugin factory by type key +func (r *pluginRegistry) GetFactory(typeKey string) (PluginFactoryFunc, error) { + if pluginFactory, ok := r.pluginsFactory[typeKey]; ok { + return pluginFactory, nil + } + return nil, fmt.Errorf("plugin type %s not found", typeKey) +} + +// Retrieves a plugin instance by type key +func (r *pluginRegistry) GetPlugin(typeKey string) (bbrplugins.BBRPlugin, error) { + if plugin, ok := r.plugins[typeKey]; ok { + return plugin, nil + } + return nil, fmt.Errorf("plugin type %s not found", typeKey) +} + +// Constructs a new plugin (a caller can perform either type assertion of a concrete implementation of the BBR plugin) +func (r *pluginRegistry) CreatePlugin(typeKey string) (bbrplugins.BBRPlugin, error) { + if factory, ok := r.pluginsFactory[typeKey]; ok { + plugin := factory() + return plugin, nil + } + return nil, fmt.Errorf("plugin %s not registered", typeKey) +} + +// Removes a plugin factory by type key +func (r *pluginRegistry) UnregisterFactory(typeKey string) error { + if _, ok := r.pluginsFactory[typeKey]; ok { + delete(r.pluginsFactory, typeKey) + return nil + } + return fmt.Errorf("plugin (%s) not found", typeKey) +} + +// ListPlugins lists all registered plugins +func (r *pluginRegistry) ListPlugins() []string { + typeKeys := make([]string, 0, len(r.plugins)) + for k := range r.plugins { + typeKeys = append(typeKeys, k) + } + return typeKeys +} + +// ListPlugins lists all registered plugins; this functionis not really needed. 
Just for sanity checks and tests +func (r *pluginRegistry) ListFactories() []string { + typeKeys := make([]string, 0, len(r.pluginsFactory)) + for k := range r.pluginsFactory { + typeKeys = append(typeKeys, k) + } + return typeKeys +} + +// Get factories +func (r *pluginRegistry) GetFactories() map[string]PluginFactoryFunc { + return r.pluginsFactory +} + +// Get plugins +func (r *pluginRegistry) GetPlugins() map[string]bbrplugins.BBRPlugin { + return r.plugins +} + +// Checks for presense of a factory in this registry +func (r *pluginRegistry) ContainsFactory(typeKey string) bool { + _, exists := r.pluginsFactory[typeKey] + return exists +} + +// Helper: Checks for presense of a plugin in this registry +func (r *pluginRegistry) ContainsPlugin(typeKey string) bool { + _, exists := r.plugins[typeKey] + return exists +} + +func (r *pluginRegistry) String() string { + return fmt.Sprintf("{plugins=%v}{pluginsFactory=%v}", r.plugins, r.pluginsFactory) +} + +//-------------------------- PluginsChain implementation -------------------------- + +// PluginsChain is a sequence of plugins to be executed in order inside the ext_proc server +type pluginsChain struct { + plugins []string +} + +// NewPluginsChain creates a new PluginsChain instance +func NewPluginsChain() PluginsChain { + return &pluginsChain{ + plugins: []string{}, + } +} + +// AddPlugin adds a plugin to the chain +func (pc *pluginsChain) AddPlugin(typeKey string, r PluginRegistry) error { + // check whether this plugin was registered in the registry (i.e., the factory for the plugin exist and an instance was created) + if ok := r.ContainsPlugin(typeKey); !ok { + err := fmt.Errorf("plugin type %s not found", typeKey) + return err + } + pc.plugins = append(pc.plugins, typeKey) + + return nil +} + +// GetPlugin retrieves the next plugin in the chain by index +func (pc *pluginsChain) GetPlugin(index int, r PluginRegistry) (bbrplugins.BBRPlugin, error) { + if index < 0 || index >= len(pc.plugins) { + return nil, 
fmt.Errorf("plugin index %d out of range", index) + } + plugins := r.GetPlugins() + plugin, ok := plugins[pc.plugins[index]] + if !ok { + return nil, fmt.Errorf("plugin index %d is not found in the registry", index) + } + return plugin, nil +} + +// Length returns the number of plugins in the chain +func (pc *pluginsChain) Length() int { + return len(pc.plugins) +} + +// AddPluginInOrder inserts a plugin into the chain in the specified index +func (pc *pluginsChain) AddPluginAtInd(typeKey string, i int, r PluginRegistry) error { + if i < 0 || i > len(pc.plugins) { + return fmt.Errorf("index %d is out of range", i) + } + // validate that the plugin is registered + plugins := r.GetPlugins() + if _, ok := plugins[pc.plugins[i]]; !ok { + return fmt.Errorf("plugin index %d is not found in the registry", i) + } + pc.plugins = append(pc.plugins[:i], append([]string{typeKey}, pc.plugins[i:]...)...) + return nil +} + +func (pc *pluginsChain) GetPlugins() []string { + return pc.plugins +} + +// MergeMaps copies all key/value pairs from src into dst and returns dst. +// If dst is nil a new map is allocated. +// Existing keys in dst are not overwritten. +// This is a helper function used to merge headers from multiple plugins safely. 
+func MergeMaps(dst map[string]string, src map[string]string) map[string]string { + if src == nil { + if dst == nil { + return map[string]string{} + } + return dst + } + if dst == nil { + dst = make(map[string]string, len(src)) + } + + for k, v := range src { + if _, exists := dst[k]; !exists { + dst[k] = v + } + } + + return dst +} + +func (pc *pluginsChain) Run( + bodyBytes []byte, + r PluginRegistry, +) (headers map[string]string, mutateBodyBytes []byte, err error) { + + allHeaders := make(map[string]string) + mutatedBodyBytes := bodyBytes + + for i := range pc.Length() { + plugin, _ := pc.GetPlugin(i, r) + pluginType := plugin.TypedName().Type + + metExtPlugin, err := r.GetPlugin(pluginType) + + if err != nil { + return allHeaders, bodyBytes, err + } + + // The plugin i in the chain receives the (potentially mutated) body and headers from plugin i-1 in the chain + headers, mutatedBodyBytes, err := metExtPlugin.Execute(mutatedBodyBytes) + + if err != nil { + return headers, mutatedBodyBytes, err + } + + //note that the existing overlapping keys are NOT over-written by merge + MergeMaps(allHeaders, headers) + } + return allHeaders, mutatedBodyBytes, nil +} + +func (pc *pluginsChain) String() string { + return fmt.Sprintf("PluginsChain{plugins=%v}", pc.plugins) +} + +// -------------------------- End of PluginsChain implementation -------------------------- diff --git a/pkg/bbr/handlers/request.go b/pkg/bbr/handlers/request.go index e7baec6ef..9a73e47f5 100644 --- a/pkg/bbr/handlers/request.go +++ b/pkg/bbr/handlers/request.go @@ -18,54 +18,50 @@ package handlers import ( "context" - "encoding/json" + "strings" basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics" + bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" + helpers 
"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -const modelHeader = "X-Gateway-Model-Name" - -type RequestBody struct { - Model string `json:"model"` -} - // HandleRequestBody handles request bodies. func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) ([]*eppb.ProcessingResponse, error) { logger := log.FromContext(ctx) var ret []*eppb.ProcessingResponse - var requestBody RequestBody - if err := json.Unmarshal(requestBodyBytes, &requestBody); err != nil { - metrics.RecordModelNotParsedCounter() - return nil, err + allHeaders, mutatedBodyBytes, err := s.requestChain.Run(requestBodyBytes, s.registry) + + if err != nil { + //TODO: add metric in metrics.go to count "other errors" + logger.V(logutil.DEFAULT).Info("error processing body", "error", err) + ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes) + return ret, nil + } + + model, ok := allHeaders[bbrplugins.ModelHeader] + + if !ok { + //TODO: add metric in metrics.go to count "other errors" + logger.V(logutil.DEFAULT).Info("manadatory header X-Gateway-Model-Name value is undetermined") + ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes) + return ret, nil } - if requestBody.Model == "" { + if strings.TrimSpace(model) == "" { metrics.RecordModelNotInBodyCounter() - logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter") - if s.streaming { - ret = append(ret, &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &eppb.HeadersResponse{}, - }, - }) - ret = addStreamedBodyResponse(ret, requestBodyBytes) - return ret, nil - } else { - ret = append(ret, &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{}, - }, - }) - } + ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes) return ret, nil } + //TODO: change 
to DEBUG + logger.V(logutil.DEFAULT).Info("model extracted from request body", "model", model) + metrics.RecordSuccessCounter() if s.streaming { @@ -78,8 +74,8 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) SetHeaders: []*basepb.HeaderValueOption{ { Header: &basepb.HeaderValue{ - Key: modelHeader, - RawValue: []byte(requestBody.Model), + Key: bbrplugins.ModelHeader, + RawValue: []byte(model), }, }, }, @@ -88,7 +84,11 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) }, }, }) - ret = addStreamedBodyResponse(ret, requestBodyBytes) + ret = addStreamedBodyResponse(ret, mutatedBodyBytes) + + //TODO: change to DEBUG + logger.V(logutil.DEFAULT).Info("RESPONSE", "response", helpers.PrettyPrintResponses(ret)) + return ret, nil } @@ -103,12 +103,17 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) SetHeaders: []*basepb.HeaderValueOption{ { Header: &basepb.HeaderValue{ - Key: modelHeader, - RawValue: []byte(requestBody.Model), + Key: bbrplugins.ModelHeader, + RawValue: []byte(model), }, }, }, }, + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: mutatedBodyBytes, + }, + }, }, }, }, @@ -116,7 +121,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) }, nil } -func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBytes []byte) []*eppb.ProcessingResponse { +func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, mutatedBodyBytes []byte) []*eppb.ProcessingResponse { return append(responses, &eppb.ProcessingResponse{ Response: &eppb.ProcessingResponse_RequestBody{ RequestBody: &eppb.BodyResponse{ @@ -124,7 +129,7 @@ func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBy BodyMutation: &eppb.BodyMutation{ Mutation: &eppb.BodyMutation_StreamedResponse{ StreamedResponse: &eppb.StreamedBodyResponse{ - Body: requestBodyBytes, + Body: mutatedBodyBytes, 
EndOfStream: true, }, }, @@ -156,3 +161,31 @@ func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) ([]*eppb.Pro }, }, nil } + +// buildEmptyResponsesForMissingModel is a local helper that returns the appropriate empty responses +// for the "model not found" branch depending on streaming mode. +// It is also used to create empty responses in case of other errors related to running plugins on the body +// This is not very clean and MUST be segregated in the future. +// Corresponding metrics should be defined to make different errors observable +func buildEmptyResponsesForMissingModel(streaming bool, requestBodyBytes []byte) ([]*eppb.ProcessingResponse, error) { + var ret []*eppb.ProcessingResponse + + if streaming { + // Emit empty headers response, then stream body unchanged. + ret = append(ret, &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &eppb.HeadersResponse{}, + }, + }) + ret = addStreamedBodyResponse(ret, requestBodyBytes) + return ret, nil + } + + // Non-streaming: emit empty body response. 
+ ret = append(ret, &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_RequestBody{ + RequestBody: &eppb.BodyResponse{}, + }, + }) + return ret, nil +} diff --git a/pkg/bbr/handlers/request_test.go b/pkg/bbr/handlers/request_test.go index 9e408fdef..4ae6d46b9 100644 --- a/pkg/bbr/handlers/request_test.go +++ b/pkg/bbr/handlers/request_test.go @@ -30,6 +30,7 @@ import ( crmetrics "sigs.k8s.io/controller-runtime/pkg/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics" + utils "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -176,9 +177,15 @@ func TestHandleRequestBody(t *testing.T) { }, } + //Initialize PluginRegistry and request/response PluginsChain instances based on the minimal configuration setting vi env vars + registry, requestChain, responseChain, err := utils.InitPlugins() + if err != nil { + t.Fatalf("processRequestBody(): %v", err) + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - server := &Server{streaming: test.streaming} + server := NewServer(test.streaming, registry, requestChain, responseChain) bodyBytes, _ := json.Marshal(test.body) resp, err := server.HandleRequestBody(ctx, bodyBytes) if err != nil { diff --git a/pkg/bbr/handlers/server.go b/pkg/bbr/handlers/server.go index 6488453be..a8128a5b9 100644 --- a/pkg/bbr/handlers/server.go +++ b/pkg/bbr/handlers/server.go @@ -27,18 +27,29 @@ import ( "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" ) -func NewServer(streaming bool) *Server { - return &Server{streaming: streaming} +func NewServer(streaming bool, + reg framework.PluginRegistry, + reqChain framework.PluginsChain, + respChain framework.PluginsChain) *Server { + 
return &Server{streaming: streaming, + registry: reg, + requestChain: reqChain, + responseChain: respChain, + } } // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type Server struct { - streaming bool + streaming bool + registry framework.PluginRegistry + requestChain framework.PluginsChain + responseChain framework.PluginsChain } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { diff --git a/pkg/bbr/handlers/server_test.go b/pkg/bbr/handlers/server_test.go index 7bc50a697..a02ceaed0 100644 --- a/pkg/bbr/handlers/server_test.go +++ b/pkg/bbr/handlers/server_test.go @@ -26,12 +26,19 @@ import ( "google.golang.org/protobuf/testing/protocmp" "sigs.k8s.io/controller-runtime/pkg/log" + bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" + utils "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) func TestProcessRequestBody(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) + //set environment variables expected by the code under test + //testing a minimal configuration + //request plugin chain always contains the default bbr plugin that extracts a model name and sets it on the X-Gateway-Model-Name header + t.Setenv("REQUEST_PLUGINS_CHAIN", "simple_model_extractor") + cases := []struct { desc string streaming bool @@ -58,7 +65,7 @@ func TestProcessRequestBody(t *testing.T) { SetHeaders: []*basepb.HeaderValueOption{ { Header: &basepb.HeaderValue{ - Key: modelHeader, + Key: bbrplugins.ModelHeader, RawValue: []byte("foo"), }, }, @@ -93,7 +100,7 @@ func TestProcessRequestBody(t *testing.T) { SetHeaders: []*basepb.HeaderValueOption{ { Header: &basepb.HeaderValue{ - Key: modelHeader, + Key: bbrplugins.ModelHeader, RawValue: []byte("foo"), }, }, @@ -125,9 +132,15 @@ func TestProcessRequestBody(t *testing.T) { 
}, } + //Initialize PluginRegistry and request/response PluginsChain instances based on the minimal configuration setting vi env vars + registry, requestChain, responseChain, err := utils.InitPlugins() + if err != nil { + t.Fatalf("processRequestBody(): %v", err) + } + for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - srv := NewServer(tc.streaming) + srv := NewServer(tc.streaming, registry, requestChain, responseChain) streamedBody := &streamedBody{} for i, body := range tc.bodys { got, err := srv.processRequestBody(context.Background(), body, streamedBody, log.FromContext(ctx)) diff --git a/pkg/bbr/plugins/interfaces.go b/pkg/bbr/plugins/interfaces.go new file mode 100644 index 000000000..42eb3d22b --- /dev/null +++ b/pkg/bbr/plugins/interfaces.go @@ -0,0 +1,67 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
// ------------------------------------ Defaults ------------------------------------------
const (
	// PluginTypePattern is the pattern a plugin type must match. Even though a
	// BBRPlugin type is not a K8s resource, it is logically akin to `kind`:
	// it starts with an upper-case letter and continues with alphanumeric
	// characters only (CamelCase).
	PluginTypePattern = `^[A-Z][A-Za-z0-9]*$`
	// MaxPluginTypeLength is the maximum length of a plugin type.
	MaxPluginTypeLength = 63
	// DefaultPluginType is the type of the default plugin implementation.
	DefaultPluginType = "MetaDataExtractor"
	// PluginNamePattern is the pattern a plugin name must match. Even though
	// BBRPlugin is not a K8s resource yet, its naming follows K8s resource
	// naming: lowercase letters, digits, hyphens and dots are allowed; the
	// name must start and end with a lowercase alphanumeric character; the
	// middle and trailing groups are optional.
	PluginNamePattern = `^[a-z0-9]([-a-z0-9.]*[a-z0-9])?$`
	// DefaultPluginName is the name of the default plugin implementation.
	DefaultPluginName = "simple-model-extractor"
	// MaxPluginNameLength is the maximum length of a plugin name.
	MaxPluginNameLength = 253
	// ModelHeader is the well-known custom header that carries the model name.
	ModelHeader = "X-Gateway-Model-Name"
)

// BBRPlugin is the interface implemented by every plugin of the BBR framework.
// A plugin inspects a request body and returns the headers to set together
// with the body to forward.
type BBRPlugin interface {
	plugins.Plugin

	// RequiresFullParsing indicates whether full body parsing is required
	// to facilitate efficient memory sharing across plugins in a chain.
	RequiresFullParsing() bool

	// Execute runs the plugin logic on the request body.
	//
	// An implementation MAY return a body that differs from requestBodyBytes.
	// An implementation MUST return a map of headers; if it sets no headers,
	// the map must be empty.
	//
	// The value of a header set by an extended implementation need not be
	// identical to the value the default implementation would set for that
	// same header. Example: the body of a request sets model to
	// "semantic-model-selector", which, say, stands for "select the best
	// model for this request at minimal cost". A plugin implementing
	// "semantic-model-selector" sets X-Gateway-Model-Name to any valid model
	// name from the inventory of backend models and also mutates the body
	// accordingly. In contrast, the default implementation in this package
	// copies the `model` value into the header as-is and returns the body
	// unchanged.
	Execute(requestBodyBytes []byte) (
		headers map[string]string,
		mutatedBodyBytes []byte,
		err error,
	)
}
+*/ + +package bbrplugins + +import ( + "encoding/json" + + "fmt" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" +) + +// ------------------------------------ INTERFACES --------------------------------------------------------------- +// Interfaces are defined in "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework/plugins/interfaces.go" +// ---------------------------------------------------------------------------------------------------------------- + +// ------------------------------------ DEFAULT PLUGIN IMPLEMENTATION ---------------------------------------------- + +// defaultMetaDataExtractor implements the MetadataExtractor interface and extracts only the mmodel name AS-IS +type defaultMetaDataExtractor struct { + typedName plugins.TypedName + requiresFullParsing bool //this field will be used to determine whether shared struct should be created in this chain +} + +// NewSimpleModelExtractor is a factory that constructs SimpleModelExtractor plugin +// A developer who wishes to create her own implementation, will implement the BBRPlugin interface and +// use Registry and PluginsChain to register and execute the plugin (together with other plugins in a chain) +func NewDefaultMetaDataExtractor() BBRPlugin { + return &defaultMetaDataExtractor{ + typedName: plugins.TypedName{ + Type: DefaultPluginType, + Name: "simple-model-extractor", + }, + requiresFullParsing: false, + } +} + +func (s *defaultMetaDataExtractor) RequiresFullParsing() bool { + return s.requiresFullParsing +} + +func (s *defaultMetaDataExtractor) TypedName() plugins.TypedName { + return s.typedName +} + +// Execute extracts the "model" from the JSON request body and sets X-Gateway-Model-Name header. +// This implementation intentionally ignores metaDataKeys and does not mutate the body. +// It expects the request body to be a JSON object containing a "model" field. 
+// A nil for metaDataKeysToHeaders map SHOULD be specified by a caller for clarity +// The metaDataKeysToHeaders is explicitly ignored in this implementation +// This implementation is simply refactoring of the default BBR implementation to work with the pluggable framework +func (s *defaultMetaDataExtractor) Execute(requestBodyBytes []byte) ( + headers map[string]string, + mutatedBodyBytes []byte, + err error) { + + type RequestBody struct { + Model string `json:"model"` + } + + h := make(map[string]string) + + var requestBody RequestBody + + if err := json.Unmarshal(requestBodyBytes, &requestBody); err != nil { + // return original body on decode failure + return nil, requestBodyBytes, err + } + + if requestBody.Model == "" { + return nil, requestBodyBytes, fmt.Errorf("missing required field: model") + } + + // ModelHeader is a constant defined in ./pkg/bbr/plugins/interfaces + h[ModelHeader] = requestBody.Model + + // Body is not mutated in this implementation hence returning original requestBodyBytes. This is intentional. + return h, requestBodyBytes, nil +} + +func (s *defaultMetaDataExtractor) String() string { + return fmt.Sprintf(("BBRPlugin{%v/requiresFullParsing=%v}"), s.TypedName(), s.requiresFullParsing) +} diff --git a/pkg/bbr/server/runserver.go b/pkg/bbr/server/runserver.go index ac6ac414e..31abb28b0 100644 --- a/pkg/bbr/server/runserver.go +++ b/pkg/bbr/server/runserver.go @@ -29,21 +29,33 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" tlsutil "sigs.k8s.io/gateway-api-inference-extension/internal/tls" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/handlers" ) // ExtProcServerRunner provides methods to manage an external process server. 
type ExtProcServerRunner struct { - GrpcPort int - SecureServing bool - Streaming bool + GrpcPort int + SecureServing bool + Streaming bool + registry framework.PluginRegistry + requestPluginsChain framework.PluginsChain + responsePluginsChain framework.PluginsChain } -func NewDefaultExtProcServerRunner(port int, streaming bool) *ExtProcServerRunner { +func NewDefaultExtProcServerRunner( + port int, + streaming bool, + r framework.PluginRegistry, + reqChain framework.PluginsChain, + respChain framework.PluginsChain) *ExtProcServerRunner { return &ExtProcServerRunner{ - GrpcPort: port, - SecureServing: true, - Streaming: streaming, + GrpcPort: port, + SecureServing: true, + Streaming: streaming, + registry: r, + requestPluginsChain: reqChain, + responsePluginsChain: respChain, } } @@ -65,7 +77,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { extProcPb.RegisterExternalProcessorServer( srv, - handlers.NewServer(r.Streaming), + handlers.NewServer(r.Streaming, r.registry, r.requestPluginsChain, r.responsePluginsChain), ) // Forward to the gRPC runnable. diff --git a/pkg/bbr/utils/factory_helper.go b/pkg/bbr/utils/factory_helper.go new file mode 100644 index 000000000..8cd21e329 --- /dev/null +++ b/pkg/bbr/utils/factory_helper.go @@ -0,0 +1,43 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package utils + +import ( + framework "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" + bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" +) + +// RegisterAllFactories registers all factories for all plugins implementing BBRPlugin interface +// As more plugins are developed, this function is extended to make sure that the registry is bootstrapped to have +// all constructors upfront upon initialization when utils.init() is called. +// Whether instances of these plugins are created, depends on how the plugin chains for request and response are configured. +// By default, a request plugins chain is always present, containing "MetaDataExtractor/simple-model-extractor" plugin +// That pulls model from the body into the X-Gateway-Model-Name header +// In extended implementations (outside of IGW), if there are plugins, which are not upstream in IGW, +// this function should be extended to include factories for the plugins unknown to the base IGW code +func RegisterAllFactories(registry framework.PluginRegistry) error { + //default plugin factory registration + err := registry.RegisterFactory(bbrplugins.DefaultPluginType, + func() bbrplugins.BBRPlugin { + return bbrplugins.NewDefaultMetaDataExtractor() + }) + if err != nil { + return err + } + //another plugin factory registration here. etc. + return nil +} diff --git a/pkg/bbr/utils/helpers.go b/pkg/bbr/utils/helpers.go new file mode 100644 index 000000000..eeae18bea --- /dev/null +++ b/pkg/bbr/utils/helpers.go @@ -0,0 +1,91 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + + eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "encoding/json" +) + +// PrettyPrintResponses returns a human-readable string with: +// - Full JSON representation of each response +// - Decoded headers and body content (pretty-printed if JSON) +func PrettyPrintResponses(responses []*eppb.ProcessingResponse) string { + var builder strings.Builder + + for i, resp := range responses { + // Marshal protobuf to JSON + jsonBytes, err := protojson.MarshalOptions{ + Multiline: true, + Indent: " ", + }.Marshal(resp) + if err != nil { + builder.WriteString(fmt.Sprintf("Error marshaling response %d: %v\n", i, err)) + continue + } + + builder.WriteString(fmt.Sprintf("\n=== Response %d ===\n", i)) + builder.WriteString(string(jsonBytes)) + builder.WriteString("\n") + + // Decode headers + if headers := resp.GetRequestHeaders(); headers != nil { + builder.WriteString("\nDecoded Headers:\n") + for _, h := range headers.GetResponse().GetHeaderMutation().GetSetHeaders() { + key := h.Header.Key + raw := h.Header.RawValue + if len(raw) > 0 { + decoded := string(raw) // RawValue is []byte, safe to convert + builder.WriteString(fmt.Sprintf(" %s: %s\n", key, decoded)) + } + } + } + + // Decode body + if body := resp.GetRequestBody(); body != nil { + mutation := body.GetResponse().GetBodyMutation() + if mutation != nil { + if streamed := mutation.GetStreamedResponse(); streamed != nil { + builder.WriteString("\nDecoded Streamed Body:\n") + 
builder.WriteString(prettyIfJSON(streamed.Body)) + } else if raw := mutation.GetBody(); len(raw) > 0 { + builder.WriteString("\nDecoded Body:\n") + builder.WriteString(prettyIfJSON(raw)) + } + } + } + builder.WriteString("\n====================\n") + } + + return builder.String() +} + +// prettyIfJSON tries to pretty-print JSON if valid, else returns raw text +func prettyIfJSON(data []byte) string { + var obj interface{} + if err := json.Unmarshal(data, &obj); err == nil { + pretty, _ := json.MarshalIndent(obj, " ", " ") + return string(pretty) + "\n" + } + return string(data) + "\n" +} diff --git a/pkg/bbr/utils/init.go b/pkg/bbr/utils/init.go new file mode 100644 index 000000000..9cf204768 --- /dev/null +++ b/pkg/bbr/utils/init.go @@ -0,0 +1,147 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// Initializes PluginRegistry from environment variables (as set in the helm chart) + +package utils + +import ( + "fmt" + "os" + "regexp" + "strings" + + framework "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" + bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" +) + +func InitPlugins() ( + framework.PluginRegistry, + framework.PluginsChain, + framework.PluginsChain, + error) { + + //The environment variables defining plugins repertoire and plugin chains are set via Helm chart + registry := framework.NewPluginRegistry() + requestChain := framework.NewPluginsChain() + responseChain := framework.NewPluginsChain() + + //Define a standardized regex patterns for plugin types + var pluginTypeRe = regexp.MustCompile(bbrplugins.PluginTypePattern) + + //Define a standardized regex pattern for plugin names + var pluginNameRe = regexp.MustCompile(bbrplugins.PluginNamePattern) + + //helper to validate plugin name + isValidPluginName := func(name string) bool { + if len(name) == 0 || len(name) > bbrplugins.MaxPluginNameLength { + return false + } + return pluginNameRe.MatchString(name) + } + + //helper to validate plugin type name + isValidPluginType := func(name string) bool { + if len(name) == 0 || len(name) > bbrplugins.MaxPluginTypeLength { + return false + } + return pluginTypeRe.MatchString(name) + } + + //helper to process plugins + processPlugin := func(pluginType string, chain framework.PluginsChain) error { + + //create the plugin instance + plugin, err := registry.CreatePlugin(pluginType) + if err != nil { + return fmt.Errorf("failed to create an instance of %s %v", pluginType, err) + } + + //register the plugin instance + err = registry.RegisterPlugin(plugin) + if err != nil { + return fmt.Errorf("failed to register an instance of %s %v", pluginType, err) + } + + //Add plugin type name to the pluginsChain instance + err = chain.AddPlugin(pluginType, registry) + if err != nil { + return fmt.Errorf("failed to add plugin 
instance %s %v", pluginType, err) + } + return nil + } + + // Helper to process plugin chains + processChain := func(envVar string, chain framework.PluginsChain) error { + envPluginsChain := os.Getenv(envVar) + + if envPluginsChain == "" { + return nil // no plugins defined for this chain, but this is not an error + } + + //Plugins are specified as, e.g., REQUEST_PLUGINS_CHAIN=MetaDataExtractor:simple-model-extractor, MyPluginType:my-plugin-name + parts := strings.Split(envPluginsChain, ",") + + for i, part := range parts { + typedname := strings.TrimSpace(part) + + subparts := strings.Split(typedname, ":") + pluginType := subparts[0] + pluginName := subparts[1] + + //validate plugin type naming rules + if !isValidPluginType(pluginType) { + return fmt.Errorf("plugin %d: invalid type %s", i, pluginType) + } + + //validate plugin naming rules + if !isValidPluginName(pluginName) { + return fmt.Errorf("plugin %d: invalid type %s", i, pluginType) + } + + //process this plugin: create an instance, register it in a registry , and add to plugin chain by type name + if err := processPlugin(pluginType, chain); err != nil { + return fmt.Errorf("failed to install plugin %d: %s", i, pluginType) + } + } + return nil + } + + // Pre-register all BBRPlugin factories factories + if err := RegisterAllFactories(registry); err != nil { + return nil, nil, nil, err + } + + // Process request plugins chain + if err := processChain(framework.RequestPluginChain, requestChain); err != nil { //requestPlugins chain need not be explicitly specified if the only pluginn is the default one + return nil, nil, nil, err + } + + // Process response plugins chain + if err := processChain(framework.ResponsePluginChain, responseChain); err != nil { //responsePluginsChain is currently left empty + return nil, nil, nil, err + } + + // If request chain is empty (i.e., it was not explicitly specified in the env via Helm), add default MetadataExtractor + if requestChain.Length() == 0 { + //use default plugin 
+ if err := processPlugin(bbrplugins.DefaultPluginType, requestChain); err != nil { + return nil, nil, nil, fmt.Errorf("failed to create default MetaDataExtractor: %v", err) + } + } + + return registry, requestChain, responseChain, nil +} diff --git a/test/integration/bbr/harness_test.go b/test/integration/bbr/harness_test.go index b55d9e345..97b9439a4 100644 --- a/test/integration/bbr/harness_test.go +++ b/test/integration/bbr/harness_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/credentials/insecure" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/server" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" "sigs.k8s.io/gateway-api-inference-extension/test/integration" ) @@ -60,7 +61,16 @@ func NewBBRHarness(t *testing.T, ctx context.Context, streaming bool) *BBRHarnes // 2. Configure BBR Server // BBR is simpler than EPP; it doesn't need a K8s Manager. - runner := runserver.NewDefaultExtProcServerRunner(port, false) + + // 2.1 Configure pluggable BBR framework + //Initialize PluginRegistry and request/response PluginsChain instances based on the minimal configuration setting vi env vars + registry, requestChain, responseChain, err := utils.InitPlugins() + if err != nil { + logutil.Fatal(logger, err, "failed to initialize BBR pluggable framework: %v", err) + } + + runner := runserver.NewDefaultExtProcServerRunner(port, false, registry, requestChain, responseChain, metaDataKeys) + //runner := runserver.NewDefaultExtProcServerRunner(port, false) runner.SecureServing = false runner.Streaming = streaming diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go index 3edcb70f4..6e167b6bf 100644 --- a/test/integration/bbr/hermetic_test.go +++ b/test/integration/bbr/hermetic_test.go @@ -19,12 +19,19 @@ package bbr import ( "context" + "fmt" "testing" + "time" extProcPb 
"github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" + runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/server" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/utils" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" "sigs.k8s.io/gateway-api-inference-extension/test/integration" ) @@ -137,3 +144,46 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) { }) } } + +func setUpHermeticServer(streaming bool) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { + port := 9004 + + serverCtx, stopServer := context.WithCancel(context.Background()) + + //Initialize PluginRegistry and request/response PluginsChain instances based on the minimal configuration setting vi env vars + registry, requestChain, responseChain, err := utils.InitPlugins() + if err != nil { + logutil.Fatal(logger, err, "failed to initialize BBR pluggable framework: %v", err) + } + + serverRunner := runserver.NewDefaultExtProcServerRunner(port, false, registry, requestChain, responseChain) + serverRunner.SecureServing = false + serverRunner.Streaming = streaming + + go func() { + if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil { + logutil.Fatal(logger, err, "Failed to start ext-proc server") + } + }() + + address := fmt.Sprintf("localhost:%v", port) + // Create a grpc connection + conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + logutil.Fatal(logger, err, "Failed to connect", "address", address) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + client, err = extProcPb.NewExternalProcessorClient(conn).Process(ctx) + if err != nil { + logutil.Fatal(logger, err, "Failed to create client") + } + 
return client, func() { + cancel() + conn.Close() + stopServer() + + // wait a little until the goroutines actually exit + time.Sleep(5 * time.Second) + } +}