Skip to content

Commit

Permalink
extproc: decouple router package from api paths (envoyproxy#352)
Browse files Browse the repository at this point in the history
**Commit Message**

extproc: decouple router package from api paths

This is a follow-up change to decouple the router package from the API
paths. Now that we have specialized processors per API path, it does not
make sense for the router package to have path-related logic.
Now each processor is responsible for instantiating the right body
parser for the path they're processing.

**Related Issues/PRs (if applicable)**

Follow-up for envoyproxy#325 and
envoyproxy#334

**Special notes for reviewers (if applicable)**

N/A

---------

Signed-off-by: Ignasi Barrera <[email protected]>
Co-authored-by: Takeshi Yoneda <[email protected]>
Signed-off-by: Loong <[email protected]>
  • Loading branch information
2 people authored and daixiang0 committed Feb 19, 2025
1 parent 03c0fb1 commit 3f72b95
Show file tree
Hide file tree
Showing 17 changed files with 115 additions and 188 deletions.
22 changes: 17 additions & 5 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
Expand All @@ -18,17 +19,21 @@ import (
"google.golang.org/protobuf/types/known/structpb"

"github.com/envoyproxy/ai-gateway/filterapi"
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

// NewChatCompletionProcessor implements [Processor] for the /chat/completions endpoint.
func NewChatCompletionProcessor(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger) Processor {
func NewChatCompletionProcessor(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger) (Processor, error) {
if config.schema.Name != filterapi.APISchemaOpenAI {
return nil, fmt.Errorf("unsupported API schema: %s", config.schema.Name)
}
return &chatCompletionProcessor{
config: config,
requestHeaders: requestHeaders,
logger: logger,
}
}, nil
}

// chatCompletionProcessor handles the processing of the request and response messages for a single stream.
Expand Down Expand Up @@ -70,12 +75,11 @@ func (c *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *co

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (c *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := c.requestHeaders[":path"]
model, body, err := c.config.bodyParser(path, rawBody)
model, body, err := parseOpenAIChatCompletionBody(rawBody)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
c.logger.Info("Processing request", "path", path, "model", model)
c.logger.Info("Processing request", "path", c.requestHeaders[":path"], "model", model)

c.requestHeaders[c.config.modelNameHeaderKey] = model
b, err := c.config.router.Calculate(c.requestHeaders)
Expand Down Expand Up @@ -195,6 +199,14 @@ func (c *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
return resp, nil
}

// parseOpenAIChatCompletionBody decodes the raw HTTP body as an OpenAI chat
// completion request, returning the requested model name and the parsed body.
func parseOpenAIChatCompletionBody(body *extprocv3.HttpBody) (modelName string, rb translator.RequestBody, err error) {
	req := &openai.ChatCompletionRequest{}
	if unmarshalErr := json.Unmarshal(body.Body, req); unmarshalErr != nil {
		err = fmt.Errorf("failed to unmarshal body: %w", unmarshalErr)
		return
	}
	return req.Model, req, nil
}

func (c *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(c.config.requestCosts))
for i := range c.config.requestCosts {
Expand Down
88 changes: 64 additions & 24 deletions internal/extproc/chatcompletion_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package extproc

import (
"encoding/json"
"errors"
"io"
"log/slog"
Expand All @@ -16,11 +17,24 @@ import (
"github.com/stretchr/testify/require"

"github.com/envoyproxy/ai-gateway/filterapi"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

// TestChatCompletion_Schema verifies that the chat completion processor
// constructor accepts only the OpenAI API schema and rejects any other.
// Renamed from the misspelled "TestChatCompletion_Scheema"; test functions
// have no callers, so the rename is safe.
func TestChatCompletion_Schema(t *testing.T) {
	t.Run("unsupported", func(t *testing.T) {
		cfg := &processorConfig{schema: filterapi.VersionedAPISchema{Name: "Foo", Version: "v123"}}
		_, err := NewChatCompletionProcessor(cfg, nil, nil)
		require.ErrorContains(t, err, "unsupported API schema: Foo")
	})
	t.Run("supported openai", func(t *testing.T) {
		cfg := &processorConfig{schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v123"}}
		_, err := NewChatCompletionProcessor(cfg, nil, nil)
		require.NoError(t, err)
	})
}

func TestChatCompletion_SelectTranslator(t *testing.T) {
c := &chatCompletionProcessor{}
t.Run("unsupported", func(t *testing.T) {
Expand Down Expand Up @@ -128,64 +142,71 @@ func TestChatCompletion_ProcessResponseBody(t *testing.T) {
}

func TestChatCompletion_ProcessRequestBody(t *testing.T) {
bodyFromModel := func(t *testing.T, model string) []byte {
var openAIReq openai.ChatCompletionRequest
openAIReq.Model = model
bytes, err := json.Marshal(openAIReq)
require.NoError(t, err)
return bytes
}
t.Run("body parser error", func(t *testing.T) {
rbp := mockRequestBodyParser{t: t, retErr: errors.New("test error")}
p := &chatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl}}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to parse request body: test error")
p := &chatCompletionProcessor{}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("nonjson")})
require.ErrorContains(t, err, "invalid character 'o' in literal null")
})
t.Run("router error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
rt := mockRouter{t: t, expHeaders: headers, retErr: errors.New("test error")}
p := &chatCompletionProcessor{config: &processorConfig{bodyParser: rbp.impl, router: rt}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
p := &chatCompletionProcessor{config: &processorConfig{router: rt}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: bodyFromModel(t, "some-model")})
require.ErrorContains(t, err, "failed to calculate route: test error")
})
t.Run("translator not found", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
rt := mockRouter{
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
p := &chatCompletionProcessor{config: &processorConfig{router: rt}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: bodyFromModel(t, "some-model")})
require.ErrorContains(t, err, "unsupported API schema: backend={some-schema v10.0}")
})
t.Run("translator error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
someBody := bodyFromModel(t, "some-model")
rt := mockRouter{
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
tr := mockTranslator{t: t, retErr: errors.New("test error")}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
}, requestHeaders: headers, logger: slog.Default(), translator: tr}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
var body openai.ChatCompletionRequest
require.NoError(t, json.Unmarshal(someBody, &body))
tr := mockTranslator{t: t, retErr: errors.New("test error"), expRequestBody: &body}
p := &chatCompletionProcessor{
config: &processorConfig{router: rt},
requestHeaders: headers, logger: slog.Default(), translator: tr,
}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody})
require.ErrorContains(t, err, "failed to transform request: test error")
})
t.Run("ok", func(t *testing.T) {
someBody := router.RequestBody("foooooooooooooo")
someBody := bodyFromModel(t, "some-model")
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo", retRb: someBody}
rt := mockRouter{
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
headerMut := &extprocv3.HeaderMutation{}
bodyMut := &extprocv3.BodyMutation{}
mt := mockTranslator{t: t, expRequestBody: someBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut}

var expBody openai.ChatCompletionRequest
require.NoError(t, json.Unmarshal(someBody, &expBody))
mt := mockTranslator{t: t, expRequestBody: &expBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
router: rt,
selectedBackendHeaderKey: "x-ai-gateway-backend-key",
modelNameHeaderKey: "x-ai-gateway-model-key",
}, requestHeaders: headers, logger: slog.Default(), translator: mt}
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody})
require.NoError(t, err)
require.Equal(t, mt, p.translator)
require.NotNil(t, resp)
Expand All @@ -202,3 +223,22 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
require.Equal(t, "some-backend", string(hdrs[1].Header.RawValue))
})
}

// TestChatCompletion_ParseBody exercises parseOpenAIChatCompletionBody for
// both a well-formed request body and an empty (invalid JSON) body.
func TestChatCompletion_ParseBody(t *testing.T) {
	t.Run("ok", func(t *testing.T) {
		raw, marshalErr := json.Marshal(openai.ChatCompletionRequest{Model: "llama3.3"})
		require.NoError(t, marshalErr)

		model, parsed, parseErr := parseOpenAIChatCompletionBody(&extprocv3.HttpBody{Body: raw})
		require.NoError(t, parseErr)
		require.Equal(t, "llama3.3", model)
		require.NotNil(t, parsed)
	})
	t.Run("error", func(t *testing.T) {
		// An empty body is not valid JSON, so parsing must fail with zero results.
		model, parsed, parseErr := parseOpenAIChatCompletionBody(&extprocv3.HttpBody{})
		require.Error(t, parseErr)
		require.Empty(t, model)
		require.Nil(t, parsed)
	})
}
22 changes: 2 additions & 20 deletions internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/envoyproxy/ai-gateway/filterapi"
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
)

Expand Down Expand Up @@ -70,7 +69,7 @@ func (m mockProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.Ht
type mockTranslator struct {
t *testing.T
expHeaders map[string]string
expRequestBody router.RequestBody
expRequestBody translator.RequestBody
expResponseBody *extprocv3.HttpBody
retHeaderMutation *extprocv3.HeaderMutation
retBodyMutation *extprocv3.BodyMutation
Expand All @@ -80,7 +79,7 @@ type mockTranslator struct {
}

// RequestBody implements [translator.Translator.RequestBody].
func (m mockTranslator) RequestBody(body router.RequestBody) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, override *extprocv3http.ProcessingMode, err error) {
func (m mockTranslator) RequestBody(body translator.RequestBody) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, override *extprocv3http.ProcessingMode, err error) {
	// The mock only checks that the expected parsed body was received, then
	// hands back the canned results configured by the test.
	require.Equal(m.t, m.expRequestBody, body)
	headerMutation, bodyMutation = m.retHeaderMutation, m.retBodyMutation
	override, err = m.retOverride, m.retErr
	return
}
Expand Down Expand Up @@ -127,23 +126,6 @@ func (m mockRouter) Calculate(headers map[string]string) (*filterapi.Backend, er
return b, m.retErr
}

// mockRequestBodyParser implements [router.RequestBodyParser] for testing.
type mockRequestBodyParser struct {
t *testing.T
expPath string
expBody []byte
retModelName string
retRb router.RequestBody
retErr error
}

// impl implements [router.RequestBodyParser].
func (m *mockRequestBodyParser) impl(path string, body *extprocv3.HttpBody) (modelName string, rb router.RequestBody, err error) {
require.Equal(m.t, m.expPath, path)
require.Equal(m.t, m.expBody, body.Body)
return m.retModelName, m.retRb, m.retErr
}

// mockExternalProcessingStream implements [extprocv3.ExternalProcessor_ProcessServer] for testing.
type mockExternalProcessingStream struct {
t *testing.T
Expand Down
4 changes: 2 additions & 2 deletions internal/extproc/models_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type modelsProcessor struct {
var _ Processor = (*modelsProcessor)(nil)

// NewModelsProcessor creates a new processor that returns the list of declared models
func NewModelsProcessor(config *processorConfig, _ map[string]string, logger *slog.Logger) Processor {
func NewModelsProcessor(config *processorConfig, _ map[string]string, logger *slog.Logger) (Processor, error) {
models := openai.ModelList{
Object: "list",
Data: make([]openai.Model, 0, len(config.declaredModels)),
Expand All @@ -47,7 +47,7 @@ func NewModelsProcessor(config *processorConfig, _ map[string]string, logger *sl
Created: openai.JSONUNIXTime(time.Now()), // TODO(nacx): does this really matter here?
})
}
return &modelsProcessor{logger: logger, models: models}
return &modelsProcessor{logger: logger, models: models}, nil
}

// ProcessRequestHeaders implements [Processor.ProcessRequestHeaders].
Expand Down
3 changes: 2 additions & 1 deletion internal/extproc/models_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import (

func TestModels_ProcessRequestHeaders(t *testing.T) {
cfg := &processorConfig{declaredModels: []string{"openai", "aws-bedrock"}}
p := NewModelsProcessor(cfg, nil, slog.Default())
p, err := NewModelsProcessor(cfg, nil, slog.Default())
require.NoError(t, err)
res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}},
})
Expand Down
5 changes: 2 additions & 3 deletions internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ import (
"github.com/envoyproxy/ai-gateway/filterapi"
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
)

// processorConfig is the configuration for the processor.
// This will be created by the server and passed to the processor when it detects a new configuration.
type processorConfig struct {
uuid string
bodyParser router.RequestBodyParser
schema filterapi.VersionedAPISchema
router x.Router
modelNameHeaderKey, selectedBackendHeaderKey string
backendAuthHandlers map[string]backendauth.Handler
Expand All @@ -39,7 +38,7 @@ type processorConfigRequestCost struct {
}

// ProcessorFactory is the factory function used to create new instances of a processor.
type ProcessorFactory func(*processorConfig, map[string]string, *slog.Logger) Processor
type ProcessorFactory func(*processorConfig, map[string]string, *slog.Logger) (Processor, error)

// Processor is the interface for the processor.
// This decouples the processor implementation detail from the server implementation.
Expand Down
42 changes: 0 additions & 42 deletions internal/extproc/router/request_body.go

This file was deleted.

Loading

0 comments on commit 3f72b95

Please sign in to comment.