Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ignasi Barrera <[email protected]>
  • Loading branch information
nacx committed Feb 12, 2025
1 parent 6bdf00d commit e3f0359
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 80 deletions.
44 changes: 16 additions & 28 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io"
"log/slog"
"unicode/utf8"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
Expand All @@ -19,12 +18,16 @@ import (
)

// NewChatCompletionProcessor creates a [ProcessorIface] implementation for the /chat/completions endpoint.
func NewChatCompletionProcessor(config *processorConfig, logger *slog.Logger) ProcessorIface {
return &ChatCompletionProcessor{config: config, logger: logger}
func NewChatCompletionProcessor(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger) ProcessorIface {
return &chatCompletionProcessor{
config: config,
requestHeaders: requestHeaders,
logger: logger,
}
}

// ChatCompletionProcessor handles the processing of the request and response messages for a single stream.
type ChatCompletionProcessor struct {
// chatCompletionProcessor handles the processing of the request and response messages for a single stream.
type chatCompletionProcessor struct {
logger *slog.Logger
config *processorConfig
requestHeaders map[string]string
Expand All @@ -36,16 +39,15 @@ type ChatCompletionProcessor struct {
}

// ProcessRequestHeaders implements [ProcessorIface.ProcessRequestHeaders].
func (p *ChatCompletionProcessor) ProcessRequestHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.requestHeaders = headersToMap(headers)
resp := &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
func (p *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
// The request headers have already been captured at the time the processor was created.
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}
return resp, nil
}}, nil
}

// ProcessRequestBody implements [ProcessorIface.ProcessRequestBody].
func (p *ChatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := p.requestHeaders[":path"]
model, body, err := p.config.bodyParser(path, rawBody)
if err != nil {
Expand Down Expand Up @@ -108,7 +110,7 @@ func (p *ChatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBod
}

// ProcessResponseHeaders implements [ProcessorIface.ProcessResponseHeaders].
func (p *ChatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (p *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.responseHeaders = headersToMap(headers)
if enc := p.responseHeaders["content-encoding"]; enc != "" {
p.responseEncoding = enc
Expand All @@ -132,7 +134,7 @@ func (p *ChatCompletionProcessor) ProcessResponseHeaders(_ context.Context, head
}

// ProcessResponseBody implements [ProcessorIface.ProcessResponseBody].
func (p *ChatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
var br io.Reader
switch p.responseEncoding {
case "gzip":
Expand Down Expand Up @@ -178,7 +180,7 @@ func (p *ChatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
return resp, nil
}

func (p *ChatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
func (p *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(p.config.requestCosts))
for i := range p.config.requestCosts {
c := &p.config.requestCosts[i]
Expand Down Expand Up @@ -222,17 +224,3 @@ func (p *ChatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct,
},
}, nil
}

// headersToMap converts a [corev3.HeaderMap] to a Go map for easier processing.
func headersToMap(headers *corev3.HeaderMap) map[string]string {
// TODO: handle multiple headers with the same key.
hdrs := make(map[string]string)
for _, h := range headers.GetHeaders() {
if len(h.Value) > 0 {
hdrs[h.GetKey()] = h.Value
} else if utf8.Valid(h.RawValue) {
hdrs[h.GetKey()] = string(h.RawValue)
}
}
return hdrs
}
34 changes: 11 additions & 23 deletions internal/extproc/chatcompletion_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@ import (
)

func TestChatCompletion_ProcessRequestHeaders(t *testing.T) {
p := &ChatCompletionProcessor{}
p := &chatCompletionProcessor{}
res, err := p.ProcessRequestHeaders(context.Background(), &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}},
})
require.NoError(t, err)
_, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders)
require.True(t, ok)
require.Equal(t, map[string]string{"foo": "bar"}, p.requestHeaders)
}

func TestChatCompletion_ProcessResponseHeaders(t *testing.T) {
t.Run("error translation", func(t *testing.T) {
mt := &mockTranslator{t: t, expHeaders: make(map[string]string)}
p := &ChatCompletionProcessor{translator: mt}
p := &chatCompletionProcessor{translator: mt}
mt.retErr = errors.New("test error")
_, err := p.ProcessResponseHeaders(context.Background(), nil)
require.ErrorContains(t, err, "test error")
Expand All @@ -42,7 +41,7 @@ func TestChatCompletion_ProcessResponseHeaders(t *testing.T) {
}
expHeaders := map[string]string{"foo": "bar", "dog": "cat"}
mt := &mockTranslator{t: t, expHeaders: expHeaders}
p := &ChatCompletionProcessor{translator: mt}
p := &chatCompletionProcessor{translator: mt}
res, err := p.ProcessResponseHeaders(context.Background(), inHeaders)
require.NoError(t, err)
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
Expand All @@ -53,7 +52,7 @@ func TestChatCompletion_ProcessResponseHeaders(t *testing.T) {
func TestChatCompletion_ProcessResponseBody(t *testing.T) {
t.Run("error translation", func(t *testing.T) {
mt := &mockTranslator{t: t}
p := &ChatCompletionProcessor{translator: mt}
p := &chatCompletionProcessor{translator: mt}
mt.retErr = errors.New("test error")
_, err := p.ProcessResponseBody(context.Background(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "test error")
Expand All @@ -72,7 +71,7 @@ func TestChatCompletion_ProcessResponseBody(t *testing.T) {
require.NoError(t, err)
celProgUint, err := llmcostcel.NewProgram("uint(9999)")
require.NoError(t, err)
p := &ChatCompletionProcessor{translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), config: &processorConfig{
p := &chatCompletionProcessor{translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), config: &processorConfig{
metadataNamespace: "ai_gateway_llm_ns",
requestCosts: []processorConfigRequestCost{
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
Expand Down Expand Up @@ -109,15 +108,15 @@ func TestChatCompletion_ProcessResponseBody(t *testing.T) {
func TestChatCompletion_ProcessRequestBody(t *testing.T) {
t.Run("body parser error", func(t *testing.T) {
rbp := mockRequestBodyParser{t: t, retErr: errors.New("test error")}
p := &ChatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl}}
p := &chatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl}}
_, err := p.ProcessRequestBody(context.Background(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to parse request body: test error")
})
t.Run("router error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
rt := mockRouter{t: t, expHeaders: headers, retErr: errors.New("test error")}
p := &ChatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl, router: rt}, requestHeaders: headers, logger: slog.Default()}
p := &chatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl, router: rt}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(context.Background(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to calculate route: test error")
})
Expand All @@ -128,7 +127,7 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
p := &ChatCompletionProcessor{config: &processorConfig{
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: make(map[filterapi.VersionedAPISchema]translator.Factory),
}, requestHeaders: headers, logger: slog.Default()}
Expand All @@ -143,7 +142,7 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retErr: errors.New("test error"), expPath: "/foo"}
p := &ChatCompletionProcessor{config: &processorConfig{
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
Expand All @@ -160,7 +159,7 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retTranslator: mockTranslator{t: t, retErr: errors.New("test error")}, expPath: "/foo"}
p := &ChatCompletionProcessor{config: &processorConfig{
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
Expand All @@ -181,7 +180,7 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
bodyMut := &extprocv3.BodyMutation{}
mt := mockTranslator{t: t, expRequestBody: someBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut}
factory := mockTranslatorFactory{t: t, retTranslator: mt, expPath: "/foo"}
p := &ChatCompletionProcessor{config: &processorConfig{
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
Expand All @@ -206,14 +205,3 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
require.Equal(t, "some-backend", string(hdrs[1].Header.RawValue))
})
}

func Test_headersToMap(t *testing.T) {
hm := &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{
{Key: "foo", Value: "bar"},
{Key: "dog", RawValue: []byte("cat")},
},
}
m := headersToMap(hm)
require.Equal(t, map[string]string{"foo": "bar", "dog": "cat"}, m)
}
18 changes: 9 additions & 9 deletions internal/extproc/models_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ import (
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
)

// ModelsProcessor implements [ProcessorIface] for the `/v1/models` endpoint.
// modelsProcessor implements [ProcessorIface] for the `/v1/models` endpoint.
// This processor returns an immediate response with the list of models that are declared in the filter
// configuration.
// Since it returns an immediate response after processing the headers, the rest of the methods of the
// ProcessorIface are not implemented. Those should never be called.
type ModelsProcessor struct {
type modelsProcessor struct {
logger *slog.Logger
models openai.ModelList
}

var _ ProcessorIface = (*ModelsProcessor)(nil)
var _ ProcessorIface = (*modelsProcessor)(nil)

// NewModelsProcessor creates a new processor that returns the list of declared models.
func NewModelsProcessor(config *processorConfig, logger *slog.Logger) ProcessorIface {
func NewModelsProcessor(config *processorConfig, _ map[string]string, logger *slog.Logger) ProcessorIface {
models := openai.ModelList{
Object: "list",
Data: make([]openai.Model, 0, len(config.declaredModels)),
Expand All @@ -42,11 +42,11 @@ func NewModelsProcessor(config *processorConfig, logger *slog.Logger) ProcessorI
Created: openai.JSONUNIXTime(time.Now()), // TODO(nacx): does this really matter here?
})
}
return &ModelsProcessor{logger: logger, models: models}
return &modelsProcessor{logger: logger, models: models}
}

// ProcessRequestHeaders implements [ProcessorIface.ProcessRequestHeaders].
func (m *ModelsProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
func (m *modelsProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
m.logger.Info("Serving list of declared models")

body, err := json.Marshal(m.models)
Expand All @@ -73,17 +73,17 @@ func (m *ModelsProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.Hea
var errUnexpectedCall = errors.New("unexpected method call")

// ProcessRequestBody implements [ProcessorIface.ProcessRequestBody].
func (m *ModelsProcessor) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
func (m *modelsProcessor) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
return nil, fmt.Errorf("%w: ProcessRequestBody", errUnexpectedCall)
}

// ProcessResponseHeaders implements [ProcessorIface.ProcessResponseHeaders].
func (m *ModelsProcessor) ProcessResponseHeaders(context.Context, *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
func (m *modelsProcessor) ProcessResponseHeaders(context.Context, *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
return nil, fmt.Errorf("%w: ProcessResponseHeaders", errUnexpectedCall)
}

// ProcessResponseBody implements [ProcessorIface.ProcessResponseBody].
func (m *ModelsProcessor) ProcessResponseBody(context.Context, *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
func (m *modelsProcessor) ProcessResponseBody(context.Context, *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
return nil, fmt.Errorf("%w: ProcessResponseBody", errUnexpectedCall)
}

Expand Down
4 changes: 2 additions & 2 deletions internal/extproc/models_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

func TestModels_ProcessRequestHeaders(t *testing.T) {
cfg := &processorConfig{declaredModels: []string{"openai", "aws-bedrock"}}
p := NewModelsProcessor(cfg, slog.Default())
p := NewModelsProcessor(cfg, nil, slog.Default())
res, err := p.ProcessRequestHeaders(context.Background(), &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}},
})
Expand All @@ -43,7 +43,7 @@ func TestModels_ProcessRequestHeaders(t *testing.T) {
}

func TestModels_UnimplementedMethods(t *testing.T) {
p := &ModelsProcessor{}
p := &modelsProcessor{}
_, err := p.ProcessRequestBody(context.Background(), &extprocv3.HttpBody{})
require.ErrorIs(t, err, errUnexpectedCall)
_, err = p.ProcessResponseHeaders(context.Background(), &corev3.HeaderMap{})
Expand Down
2 changes: 1 addition & 1 deletion internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type processorConfigRequestCost struct {
}

// ProcessorFactory is the factory function used to create new instances of a processor.
type ProcessorFactory func(*processorConfig, *slog.Logger) ProcessorIface
type ProcessorFactory func(*processorConfig, map[string]string, *slog.Logger) ProcessorIface

// ProcessorIface is the interface for the processor.
// This decouples the processor implementation detail from the server implementation.
Expand Down
28 changes: 14 additions & 14 deletions internal/extproc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,13 @@ func (s *Server) Register(path string, newProcessor ProcessorFactory) {

// processorForPath returns the processor for the given path.
// Only exact path matching is currently supported.
func (s *Server) processorForPath(path string) (ProcessorIface, error) {
func (s *Server) processorForPath(requestHeaders map[string]string) (ProcessorIface, error) {
path := requestHeaders[":path"]
newProcessor, ok := s.processors[path]
if !ok {
return nil, fmt.Errorf("no processor defined for path: %v", path)
}
return newProcessor(s.config, s.logger), nil
return newProcessor(s.config, requestHeaders, s.logger), nil
}

// Process implements [extprocv3.ExternalProcessorServer].
Expand Down Expand Up @@ -164,10 +165,9 @@ func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error
// If we're processing the request headers, read the :path header to instantiate the
// right processor.
if headers := req.GetRequestHeaders().GetHeaders(); headers != nil {
path := getHeader(headers, ":path")
p, err = s.processorForPath(path)
p, err = s.processorForPath(headersToMap(headers))
if err != nil {
s.logger.Error("cannot get processor for path", slog.String("path", path), slog.String("error", err.Error()))
s.logger.Error("cannot get processor", slog.String("error", err.Error()))
return status.Error(codes.NotFound, err.Error())
}
}
Expand Down Expand Up @@ -299,16 +299,16 @@ func filterSensitiveBody(resp *extprocv3.ProcessingResponse, logger *slog.Logger
return filteredResp
}

func getHeader(headers *corev3.HeaderMap, name string) string {
// headersToMap converts a [corev3.HeaderMap] to a Go map for easier processing.
func headersToMap(headers *corev3.HeaderMap) map[string]string {
// TODO: handle multiple headers with the same key.
hdrs := make(map[string]string)
for _, h := range headers.GetHeaders() {
if h.GetKey() == name {
if len(h.Value) > 0 {
return h.Value
} else if utf8.Valid(h.RawValue) {
return string(h.RawValue)
}
return ""
if len(h.Value) > 0 {
hdrs[h.GetKey()] = h.Value
} else if utf8.Valid(h.RawValue) {
hdrs[h.GetKey()] = string(h.RawValue)
}
}
return ""
return hdrs
}
17 changes: 14 additions & 3 deletions internal/extproc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func requireNewServerWithMockProcessor(t *testing.T) (*Server, *mockProcessor) {
s.config = &processorConfig{}

m := newMockProcessor(s.config, s.logger)
s.Register("/", func(*processorConfig, *slog.Logger) ProcessorIface { return m })
s.Register("/", func(*processorConfig, map[string]string, *slog.Logger) ProcessorIface { return m })

return s, m.(*mockProcessor)
}
Expand Down Expand Up @@ -264,11 +264,11 @@ func TestServer_ProcessorSelection(t *testing.T) {
require.NotNil(t, s)

s.config = &processorConfig{}
s.Register("/one", func(*processorConfig, *slog.Logger) ProcessorIface {
s.Register("/one", func(*processorConfig, map[string]string, *slog.Logger) ProcessorIface {
// Returning nil guarantees that the test will fail if this processor is selected
return nil
})
s.Register("/two", func(*processorConfig, *slog.Logger) ProcessorIface {
s.Register("/two", func(*processorConfig, map[string]string, *slog.Logger) ProcessorIface {
return &mockProcessor{
t: t,
expHeaderMap: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":path", Value: "/two"}}},
Expand Down Expand Up @@ -361,3 +361,14 @@ func TestFilterSensitiveBody(t *testing.T) {
}
require.Contains(t, buf.String(), "filtering sensitive header")
}

func Test_headersToMap(t *testing.T) {
hm := &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{
{Key: "foo", Value: "bar"},
{Key: "dog", RawValue: []byte("cat")},
},
}
m := headersToMap(hm)
require.Equal(t, map[string]string{"foo": "bar", "dog": "cat"}, m)
}

0 comments on commit e3f0359

Please sign in to comment.