Skip to content

Commit

Permalink
extproc: handle requests to /v1/models based on the declared models i…
Browse files Browse the repository at this point in the history
…n the filter config

Signed-off-by: Ignasi Barrera <[email protected]>
  • Loading branch information
nacx committed Feb 11, 2025
1 parent ea541c5 commit f0c290c
Show file tree
Hide file tree
Showing 13 changed files with 672 additions and 295 deletions.
4 changes: 3 additions & 1 deletion cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ func Main() {
log.Fatalf("failed to listen: %v", err)
}

server, err := extproc.NewServer[*extproc.Processor](l, extproc.NewProcessor)
server, err := extproc.NewServer(l)
if err != nil {
log.Fatalf("failed to create external processor server: %v", err)
}
server.Register("/v1/chat/completions", extproc.NewChatCompletionProcessor)
server.Register("/v1/models", extproc.NewModelsProcessor)

if err := extproc.StartConfigWatcher(ctx, configPath, server, l, time.Second*5); err != nil {
log.Fatalf("failed to start config watcher: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion filterapi/filterconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

func TestDefaultConfig(t *testing.T) {
server, err := extproc.NewServer(slog.Default(), extproc.NewProcessor)
server, err := extproc.NewServer(slog.Default())
require.NoError(t, err)
require.NotNil(t, server)

Expand Down
40 changes: 40 additions & 0 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package openai
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -723,3 +725,41 @@ type ErrorType struct {
// The event_id of the client event that caused the error, if applicable.
EventID *string `json:"event_id,omitempty"`
}

// ModelList is described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/list
type ModelList struct {
// Data is a list of models.
Data []Model `json:"data"`
// Object is the object type, which is always "list".
Object string `json:"object"`
}

// Model is described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/object
type Model struct {
// ID is the model identifier, which can be referenced in the API endpoints.
ID string `json:"id"`
// Created is the Unix timestamp (in seconds) when the model was created.
Created JSONUNIXTime `json:"created"`
// Object is the object type, which is always "model".
Object string `json:"object"`
// OwnedBy is the organization that owns the model.
OwnedBy string `json:"owned_by"`
}

// JSONUNIXTime is a helper type to marshal/unmarshal time.Time UNIX timestamps.
type JSONUNIXTime time.Time

func (t JSONUNIXTime) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(time.Time(t).Unix(), 10)), nil
}

func (t *JSONUNIXTime) UnmarshalJSON(s []byte) error {
q, err := strconv.ParseInt(string(s), 10, 64)
if err != nil {
return err
}
*(*time.Time)(t) = time.Unix(q, 0)
return nil
}
28 changes: 28 additions & 0 deletions internal/apischema/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"encoding/json"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/openai/openai-go"
Expand Down Expand Up @@ -234,3 +235,30 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
})
}
}

func TestModelListMarshal(t *testing.T) {
var (
model = Model{
ID: "gpt-3.5-turbo",
Object: "model",
OwnedBy: "tetrate",
Created: JSONUNIXTime(time.Date(2025, 0o1, 0o1, 0, 0, 0, 0, time.UTC)),
}
list = ModelList{Object: "list", Data: []Model{model}}
raw = `{"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model","owned_by":"tetrate","created":1735689600}]}`
)

b, err := json.Marshal(list)
require.NoError(t, err)
require.JSONEq(t, raw, string(b))

var out ModelList
require.NoError(t, json.Unmarshal([]byte(raw), &out))
require.Len(t, out.Data, 1)
require.Equal(t, "list", out.Object)
require.Equal(t, model.ID, out.Data[0].ID)
require.Equal(t, model.Object, out.Data[0].Object)
require.Equal(t, model.OwnedBy, out.Data[0].OwnedBy)
// Unmarshalling initializes other fields in time.Time we're not interested with. Just compare the actual time
require.Equal(t, time.Time(model.Created).Unix(), time.Time(out.Data[0].Created).Unix())
}
238 changes: 238 additions & 0 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
package extproc

import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"log/slog"
"unicode/utf8"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"google.golang.org/protobuf/types/known/structpb"

"github.com/envoyproxy/ai-gateway/filterapi"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

// NewChatCompletionProcessor creates a new processor for the chat completino requests
func NewChatCompletionProcessor(config *processorConfig, logger *slog.Logger) ProcessorIface {
return &ChatCompletionProcessor{config: config, logger: logger}
}

// ChatCompletionProcessor handles the processing of the request and response messages for a single stream.
type ChatCompletionProcessor struct {
logger *slog.Logger
config *processorConfig
requestHeaders map[string]string
responseHeaders map[string]string
responseEncoding string
translator translator.Translator
// cost is the cost of the request that is accumulated during the processing of the response.
costs translator.LLMTokenUsage
}

// ProcessRequestHeaders implements [ProcessorIface.ProcessRequestHeaders].
func (p *ChatCompletionProcessor) ProcessRequestHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.requestHeaders = headersToMap(headers)
resp := &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}
return resp, nil
}

// ProcessRequestBody implements [ProcessorIface.ProcessRequestBody].
func (p *ChatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := p.requestHeaders[":path"]
model, body, err := p.config.bodyParser(path, rawBody)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
p.logger.Info("Processing request", "path", path, "model", model)

p.requestHeaders[p.config.modelNameHeaderKey] = model
b, err := p.config.router.Calculate(p.requestHeaders)
if err != nil {
return nil, fmt.Errorf("failed to calculate route: %w", err)
}
p.logger.Info("Selected backend", "backend", b.Name)

factory, ok := p.config.factories[b.Schema]
if !ok {
return nil, fmt.Errorf("failed to find factory for output schema %q", b.Schema)
}

t, err := factory(path)
if err != nil {
return nil, fmt.Errorf("failed to create translator: %w", err)
}
p.translator = t

headerMutation, bodyMutation, override, err := p.translator.RequestBody(body)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
}

if headerMutation == nil {
headerMutation = &extprocv3.HeaderMutation{}
}
// Set the model name to the request header with the key `x-ai-gateway-llm-model-name`.
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.modelNameHeaderKey, RawValue: []byte(model)},
}, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
})

if authHandler, ok := p.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(ctx, p.requestHeaders, headerMutation, bodyMutation); err != nil {
return nil, fmt.Errorf("failed to do auth request: %w", err)
}
}

resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_RequestBody{
RequestBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation,
BodyMutation: bodyMutation,
ClearRouteCache: true,
},
},
},
ModeOverride: override,
}
return resp, nil
}

// ProcessResponseHeaders implements [ProcessorIface.ProcessResponseHeaders].
func (p *ChatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.responseHeaders = headersToMap(headers)
if enc := p.responseHeaders["content-encoding"]; enc != "" {
p.responseEncoding = enc
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{},
}}, nil
}
headerMutation, err := p.translator.ResponseHeaders(p.responseHeaders)
if err != nil {
return nil, fmt.Errorf("failed to transform response headers: %w", err)
}
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{
Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation},
},
}}, nil
}

// ProcessResponseBody implements [ProcessorIface.ProcessResponseBody].
func (p *ChatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
var br io.Reader
switch p.responseEncoding {
case "gzip":
br, err = gzip.NewReader(bytes.NewReader(body.Body))
if err != nil {
return nil, fmt.Errorf("failed to decode gzip: %w", err)
}
default:
br = bytes.NewReader(body.Body)
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}}, nil
}

headerMutation, bodyMutation, tokenUsage, err := p.translator.ResponseBody(p.responseHeaders, br, body.EndOfStream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}

resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ResponseBody{
ResponseBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation,
BodyMutation: bodyMutation,
},
},
},
}

// TODO: this is coupled with "LLM" specific logic. Once we have another use case, we need to refactor this.
p.costs.InputTokens += tokenUsage.InputTokens
p.costs.OutputTokens += tokenUsage.OutputTokens
p.costs.TotalTokens += tokenUsage.TotalTokens
if body.EndOfStream && len(p.config.requestCosts) > 0 {
resp.DynamicMetadata, err = p.maybeBuildDynamicMetadata()
if err != nil {
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)
}
}
return resp, nil
}

func (p *ChatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(p.config.requestCosts))
for i := range p.config.requestCosts {
c := &p.config.requestCosts[i]
var cost uint32
switch c.Type {
case filterapi.LLMRequestCostTypeInputToken:
cost = p.costs.InputTokens
case filterapi.LLMRequestCostTypeOutputToken:
cost = p.costs.OutputTokens
case filterapi.LLMRequestCostTypeTotalToken:
cost = p.costs.TotalTokens
case filterapi.LLMRequestCostTypeCELExpression:
costU64, err := llmcostcel.EvaluateProgram(
c.celProg,
p.requestHeaders[p.config.modelNameHeaderKey],
p.requestHeaders[p.config.selectedBackendHeaderKey],
p.costs.InputTokens,
p.costs.OutputTokens,
p.costs.TotalTokens,
)
if err != nil {
return nil, fmt.Errorf("failed to evaluate CEL expression: %w", err)
}
cost = uint32(costU64) //nolint:gosec
default:
return nil, fmt.Errorf("unknown request cost kind: %s", c.Type)
}
p.logger.Info("Setting request cost metadata", "type", c.Type, "cost", cost, "metadataKey", c.MetadataKey)
metadata[c.MetadataKey] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(cost)}}
}
if len(metadata) == 0 {
return nil, nil
}
return &structpb.Struct{
Fields: map[string]*structpb.Value{
p.config.metadataNamespace: {
Kind: &structpb.Value_StructValue{
StructValue: &structpb.Struct{Fields: metadata},
},
},
},
}, nil
}

// headersToMap converts a [corev3.HeaderMap] to a Go map for easier processing.
func headersToMap(headers *corev3.HeaderMap) map[string]string {
// TODO: handle multiple headers with the same key.
hdrs := make(map[string]string)
for _, h := range headers.GetHeaders() {
if len(h.Value) > 0 {
hdrs[h.GetKey()] = h.Value
} else if utf8.Valid(h.RawValue) {
hdrs[h.GetKey()] = string(h.RawValue)
}
}
return hdrs
}
Loading

0 comments on commit f0c290c

Please sign in to comment.