Skip to content

Commit

Permalink
extproc: custom processors per path and serve /v1/models (#325)
Browse files Browse the repository at this point in the history
**Commit Message**

extproc: custom processors per path and serve /v1/models

Refactors the server processing to allow registering custom Processors
for different request paths,
and adds a custom processor for requests to `/v1/models` that returns an
immediate response based
on the models that are configured in the filter configuration.

**Related Issues/PRs (if applicable)**

Related discussion: #186

---------

Signed-off-by: Ignasi Barrera <[email protected]>
  • Loading branch information
nacx authored Feb 12, 2025
1 parent 82b039f commit f07a7ff
Show file tree
Hide file tree
Showing 14 changed files with 748 additions and 363 deletions.
64 changes: 37 additions & 27 deletions cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mainlib

import (
"context"
"errors"
"flag"
"fmt"
"log"
Expand All @@ -21,59 +22,66 @@ import (
"github.com/envoyproxy/ai-gateway/internal/version"
)

// parseAndValidateFlags parses and validates the flags passed to the external processor.
func parseAndValidateFlags(args []string) (configPath, addr string, logLevel slog.Level, err error) {
fs := flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
configPathPtr := fs.String(
// extProcFlags is the struct that holds the flags passed to the external processor.
type extProcFlags struct {
configPath string // path to the configuration file.
extProcAddr string // gRPC address for the external processor.
logLevel slog.Level // log level for the external processor.
}

// parseAndValidateFlags parses and validates the flas passed to the external processor.
func parseAndValidateFlags(args []string) (extProcFlags, error) {
var (
flags extProcFlags
errs []error
fs = flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
)

fs.StringVar(&flags.configPath,
"configPath",
"",
"path to the configuration file. The file must be in YAML format specified in filterapi.Config type. "+
"The configuration file is watched for changes.",
)
extProcAddrPtr := fs.String(
fs.StringVar(&flags.extProcAddr,
"extProcAddr",
":1063",
"gRPC address for the external processor. For example, :1063 or unix:///tmp/ext_proc.sock",
)
logLevelPtr := fs.String(
"logLevel",
"info", "log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
"info",
"log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
)

if err = fs.Parse(args); err != nil {
err = fmt.Errorf("failed to parse flags: %w", err)
return
if err := fs.Parse(args); err != nil {
return extProcFlags{}, fmt.Errorf("failed to parse extProcFlags: %w", err)
}

if *configPathPtr == "" {
err = fmt.Errorf("configPath must be provided")
return
if flags.configPath == "" {
errs = append(errs, fmt.Errorf("configPath must be provided"))
}

if err = logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
err = fmt.Errorf("failed to unmarshal log level: %w", err)
return
if err := flags.logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
errs = append(errs, fmt.Errorf("failed to unmarshal log level: %w", err))
}

configPath = *configPathPtr
addr = *extProcAddrPtr
return
return flags, errors.Join(errs...)
}

// Main is a main function for the external processor exposed
// for allowing users to build their own external processor.
func Main() {
configPath, extProcAddr, level, err := parseAndValidateFlags(os.Args[1:])
flags, err := parseAndValidateFlags(os.Args[1:])
if err != nil {
log.Fatalf("failed to parse and validate flags: %v", err)
log.Fatalf("failed to parse and validate extProcFlags: %v", err)
}

l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: flags.logLevel}))

l.Info("starting external processor",
slog.String("version", version.Version),
slog.String("address", extProcAddr),
slog.String("configPath", configPath),
slog.String("address", flags.extProcAddr),
slog.String("configPath", flags.configPath),
)

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -84,17 +92,19 @@ func Main() {
cancel()
}()

lis, err := net.Listen(listenAddress(extProcAddr))
lis, err := net.Listen(listenAddress(flags.extProcAddr))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

server, err := extproc.NewServer[*extproc.Processor](l, extproc.NewProcessor)
server, err := extproc.NewServer(l)
if err != nil {
log.Fatalf("failed to create external processor server: %v", err)
}
server.Register("/v1/chat/completions", extproc.NewChatCompletionProcessor)
server.Register("/v1/models", extproc.NewModelsProcessor)

if err := extproc.StartConfigWatcher(ctx, configPath, server, l, time.Second*5); err != nil {
if err := extproc.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); err != nil {
log.Fatalf("failed to start config watcher: %v", err)
}

Expand Down
50 changes: 19 additions & 31 deletions cmd/extproc/mainlib/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_parseAndValidateFlags(t *testing.T) {
t.Run("ok flags", func(t *testing.T) {
t.Run("ok extProcFlags", func(t *testing.T) {
for _, tc := range []struct {
name string
args []string
Expand All @@ -17,7 +18,7 @@ func Test_parseAndValidateFlags(t *testing.T) {
logLevel slog.Level
}{
{
name: "minimal flags",
name: "minimal extProcFlags",
args: []string{"-configPath", "/path/to/config.yaml"},
configPath: "/path/to/config.yaml",
addr: ":1063",
Expand Down Expand Up @@ -52,44 +53,31 @@ func Test_parseAndValidateFlags(t *testing.T) {
logLevel: slog.LevelError,
},
{
name: "all flags",
args: []string{"-configPath", "/path/to/config.yaml", "-extProcAddr", "unix:///tmp/ext_proc.sock", "-logLevel", "debug"},
name: "all extProcFlags",
args: []string{
"-configPath", "/path/to/config.yaml",
"-extProcAddr", "unix:///tmp/ext_proc.sock",
"-logLevel", "debug",
},
configPath: "/path/to/config.yaml",
addr: "unix:///tmp/ext_proc.sock",
logLevel: slog.LevelDebug,
},
} {
t.Run(tc.name, func(t *testing.T) {
configPath, addr, logLevel, err := parseAndValidateFlags(tc.args)
assert.Equal(t, tc.configPath, configPath)
assert.Equal(t, tc.addr, addr)
assert.Equal(t, tc.logLevel, logLevel)
assert.NoError(t, err)
flags, err := parseAndValidateFlags(tc.args)
require.NoError(t, err)
assert.Equal(t, tc.configPath, flags.configPath)
assert.Equal(t, tc.addr, flags.extProcAddr)
assert.Equal(t, tc.logLevel, flags.logLevel)
})
}
})
t.Run("invalid flags", func(t *testing.T) {
for _, tc := range []struct {
name string
flags []string
expErr string
}{
{
name: "missing configPath",
flags: []string{"-extProcAddr", ":1063"},
expErr: "configPath must be provided",
},
{
name: "invalid logLevel",
flags: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "invalid"},
expErr: `failed to unmarshal log level: slog: level string "invalid": unknown name`,
},
} {
t.Run(tc.name, func(t *testing.T) {
_, _, _, err := parseAndValidateFlags(tc.flags)
assert.EqualError(t, err, tc.expErr)
})
}

t.Run("invalid extProcFlags", func(t *testing.T) {
_, err := parseAndValidateFlags([]string{"-logLevel", "invalid"})
assert.EqualError(t, err, `configPath must be provided
failed to unmarshal log level: slog: level string "invalid": unknown name`)
})
}

Expand Down
2 changes: 1 addition & 1 deletion filterapi/filterconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func TestDefaultConfig(t *testing.T) {
server, err := extproc.NewServer(slog.Default(), extproc.NewProcessor)
server, err := extproc.NewServer(slog.Default())
require.NoError(t, err)
require.NotNil(t, server)

Expand Down
42 changes: 42 additions & 0 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package openai
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -723,3 +725,43 @@ type ErrorType struct {
// The event_id of the client event that caused the error, if applicable.
EventID *string `json:"event_id,omitempty"`
}

// ModelList is described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/list
type ModelList struct {
// Data is a list of models.
Data []Model `json:"data"`
// Object is the object type, which is always "list".
Object string `json:"object"`
}

// Model is described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/object
type Model struct {
// ID is the model identifier, which can be referenced in the API endpoints.
ID string `json:"id"`
// Created is the Unix timestamp (in seconds) when the model was created.
Created JSONUNIXTime `json:"created"`
// Object is the object type, which is always "model".
Object string `json:"object"`
// OwnedBy is the organization that owns the model.
OwnedBy string `json:"owned_by"`
}

// JSONUNIXTime is a helper type to marshal/unmarshal time.Time UNIX timestamps.
type JSONUNIXTime time.Time

// MarshalJSON implements [json.Marshaler].
func (t JSONUNIXTime) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(time.Time(t).Unix(), 10)), nil
}

// UnmarshalJSON implements [json.Unmarshaler].
func (t *JSONUNIXTime) UnmarshalJSON(s []byte) error {
q, err := strconv.ParseInt(string(s), 10, 64)
if err != nil {
return err
}
*(*time.Time)(t) = time.Unix(q, 0)
return nil
}
28 changes: 28 additions & 0 deletions internal/apischema/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"encoding/json"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/openai/openai-go"
Expand Down Expand Up @@ -234,3 +235,30 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
})
}
}

func TestModelListMarshal(t *testing.T) {
var (
model = Model{
ID: "gpt-3.5-turbo",
Object: "model",
OwnedBy: "tetrate",
Created: JSONUNIXTime(time.Date(2025, 0o1, 0o1, 0, 0, 0, 0, time.UTC)),
}
list = ModelList{Object: "list", Data: []Model{model}}
raw = `{"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model","owned_by":"tetrate","created":1735689600}]}`
)

b, err := json.Marshal(list)
require.NoError(t, err)
require.JSONEq(t, raw, string(b))

var out ModelList
require.NoError(t, json.Unmarshal([]byte(raw), &out))
require.Len(t, out.Data, 1)
require.Equal(t, "list", out.Object)
require.Equal(t, model.ID, out.Data[0].ID)
require.Equal(t, model.Object, out.Data[0].Object)
require.Equal(t, model.OwnedBy, out.Data[0].OwnedBy)
// Unmarshalling initializes other fields in time.Time we're not interested with. Just compare the actual time.
require.Equal(t, time.Time(model.Created).Unix(), time.Time(out.Data[0].Created).Unix())
}
Loading

0 comments on commit f07a7ff

Please sign in to comment.